Merge "Add @AppModeFull annotation to NsdManagerDownstreamTetheringTest" into main
diff --git a/framework-t/Android.bp b/framework-t/Android.bp
index ba0d4d9..d177ea9 100644
--- a/framework-t/Android.bp
+++ b/framework-t/Android.bp
@@ -51,7 +51,7 @@
         ":framework-connectivity-tiramisu-updatable-sources",
         ":framework-nearby-java-sources",
         ":framework-thread-sources",
-    ] + framework_remoteauth_srcs,
+    ],
     libs: [
         "unsupportedappusage",
         "app-compat-annotations",
@@ -126,7 +126,6 @@
         "enable-framework-connectivity-t-targets",
         "FlaggedApiDefaults",
     ],
-    api_srcs: framework_remoteauth_api_srcs,
     // Do not add static_libs to this library: put them in framework-connectivity instead.
     // The jarjar rules are only so that references to jarjared utils in
     // framework-connectivity-pre-jarjar match at runtime.
@@ -143,10 +142,8 @@
         "android.net",
         "android.net.nsd",
         "android.nearby",
-        "android.remoteauth",
         "com.android.connectivity",
         "com.android.nearby",
-        "com.android.remoteauth",
     ],
 
     hidden_api: {
diff --git a/remoteauth/service/Android.bp b/remoteauth/service/Android.bp
index ae5fe5c..98ed2b2 100644
--- a/remoteauth/service/Android.bp
+++ b/remoteauth/service/Android.bp
@@ -18,7 +18,7 @@
 
 filegroup {
     name: "remoteauth-service-srcs",
-    srcs: ["java/**/*.java"],
+    srcs: [],
 }
 
 // Main lib for remoteauth services.
diff --git a/remoteauth/tests/unit/Android.bp b/remoteauth/tests/unit/Android.bp
index 37c78c7..16a8242 100644
--- a/remoteauth/tests/unit/Android.bp
+++ b/remoteauth/tests/unit/Android.bp
@@ -26,7 +26,7 @@
     min_sdk_version: "31",
 
     // Include all test java files.
-    srcs: ["src/**/*.java"],
+    srcs: [],
 
     libs: [
         "android.test.base",
diff --git a/service-t/native/libs/libnetworkstats/NetworkTraceHandler.cpp b/service-t/native/libs/libnetworkstats/NetworkTraceHandler.cpp
index ec63e41..9b1b72d 100644
--- a/service-t/native/libs/libnetworkstats/NetworkTraceHandler.cpp
+++ b/service-t/native/libs/libnetworkstats/NetworkTraceHandler.cpp
@@ -18,6 +18,7 @@
 
 #include "netdbpf/NetworkTraceHandler.h"
 
+#include <android-base/macros.h>
 #include <arpa/inet.h>
 #include <bpf/BpfUtils.h>
 #include <log/log.h>
@@ -75,9 +76,35 @@
   uint32_t bytes = 0;
 };
 
-#define AGG_FIELDS(x)                                              \
-  (x).ifindex, (x).uid, (x).tag, (x).sport, (x).dport, (x).egress, \
-      (x).ipProto, (x).tcpFlags
+BundleKey::BundleKey(const PacketTrace& pkt)  // Builds a key from one packet.
+    : ifindex(pkt.ifindex),
+      uid(pkt.uid),
+      tag(pkt.tag),
+      egress(pkt.egress),
+      ipProto(pkt.ipProto),
+      ipVersion(pkt.ipVersion) {
+  switch (ipProto) {  // No default: optionals not assigned below stay empty.
+    case IPPROTO_TCP:
+      tcpFlags = pkt.tcpFlags;  // Flags are TCP-only; ports are set below.
+      FALLTHROUGH_INTENDED;
+    case IPPROTO_DCCP:
+    case IPPROTO_UDP:
+    case IPPROTO_UDPLITE:
+    case IPPROTO_SCTP:
+      localPort = ntohs(pkt.egress ? pkt.sport : pkt.dport);  // src if egress
+      remotePort = ntohs(pkt.egress ? pkt.dport : pkt.sport);  // dst if egress
+      break;
+    case IPPROTO_ICMP:
+    case IPPROTO_ICMPV6:
+      icmpType = ntohs(pkt.sport);  // sport is read as the ICMP type.
+      icmpCode = ntohs(pkt.dport);  // dport is read as the ICMP code.
+      break;
+  }
+}
+
+#define AGG_FIELDS(x)                                                    \
+  (x).ifindex, (x).uid, (x).tag, (x).egress, (x).ipProto, (x).ipVersion, \
+      (x).tcpFlags, (x).localPort, (x).remotePort, (x).icmpType, (x).icmpCode
 
 std::size_t BundleHash::operator()(const BundleKey& a) const {
   std::size_t seed = 0;
@@ -179,7 +206,7 @@
       dst->set_timestamp(pkt.timestampNs);
       auto* event = dst->set_network_packet();
       event->set_length(pkt.length);
-      Fill(pkt, event);
+      Fill(BundleKey(pkt), event);
     }
     return;
   }
@@ -187,14 +214,13 @@
   uint64_t minTs = std::numeric_limits<uint64_t>::max();
   std::unordered_map<BundleKey, BundleDetails, BundleHash, BundleEq> bundles;
   for (const PacketTrace& pkt : packets) {
-    BundleKey key = pkt;
+    BundleKey key(pkt);
 
     // Dropping fields should remove them from the output and remove them from
-    // the aggregation key. In order to do the latter without changing the hash
-    // function, set the dropped fields to zero.
-    if (mDropTcpFlags) key.tcpFlags = 0;
-    if (mDropLocalPort) (key.egress ? key.sport : key.dport) = 0;
-    if (mDropRemotePort) (key.egress ? key.dport : key.sport) = 0;
+    // the aggregation key. Reset the optionals to indicate omission.
+    if (mDropTcpFlags) key.tcpFlags.reset();
+    if (mDropLocalPort) key.localPort.reset();
+    if (mDropRemotePort) key.remotePort.reset();
 
     minTs = std::min(minTs, pkt.timestampNs);
 
@@ -245,22 +271,18 @@
   }
 }
 
-void NetworkTraceHandler::Fill(const PacketTrace& src,
+void NetworkTraceHandler::Fill(const BundleKey& src,
                                NetworkPacketEvent* event) {
   event->set_direction(src.egress ? TrafficDirection::DIR_EGRESS
                                   : TrafficDirection::DIR_INGRESS);
   event->set_uid(src.uid);
   event->set_tag(src.tag);
 
-  if (!mDropLocalPort) {
-    event->set_local_port(ntohs(src.egress ? src.sport : src.dport));
-  }
-  if (!mDropRemotePort) {
-    event->set_remote_port(ntohs(src.egress ? src.dport : src.sport));
-  }
-  if (!mDropTcpFlags) {
-    event->set_tcp_flags(src.tcpFlags);
-  }
+  if (src.tcpFlags.has_value()) event->set_tcp_flags(*src.tcpFlags);
+  if (src.localPort.has_value()) event->set_local_port(*src.localPort);
+  if (src.remotePort.has_value()) event->set_remote_port(*src.remotePort);
+  if (src.icmpType.has_value()) event->set_icmp_type(*src.icmpType);
+  if (src.icmpCode.has_value()) event->set_icmp_code(*src.icmpCode);
 
   event->set_ip_proto(src.ipProto);
 
diff --git a/service-t/native/libs/libnetworkstats/NetworkTraceHandlerTest.cpp b/service-t/native/libs/libnetworkstats/NetworkTraceHandlerTest.cpp
index f2c1a86..0c4f049 100644
--- a/service-t/native/libs/libnetworkstats/NetworkTraceHandlerTest.cpp
+++ b/service-t/native/libs/libnetworkstats/NetworkTraceHandlerTest.cpp
@@ -113,7 +113,7 @@
           .length = 100,
           .uid = 10,
           .tag = 123,
-          .ipProto = 6,
+          .ipProto = IPPROTO_TCP,
           .tcpFlags = 1,
       },
   };
@@ -138,12 +138,14 @@
           .sport = htons(8080),
           .dport = htons(443),
           .egress = true,
+          .ipProto = IPPROTO_TCP,
       },
       PacketTrace{
           .timestampNs = 2,
           .sport = htons(443),
           .dport = htons(8080),
           .egress = false,
+          .ipProto = IPPROTO_TCP,
       },
   };
 
@@ -161,6 +163,42 @@
               TrafficDirection::DIR_INGRESS);
 }
 
+TEST_F(NetworkTraceHandlerTest, WriteIcmpTypeAndCode) {  // ICMP: type/code, no ports.
+  std::vector<PacketTrace> input = {
+      PacketTrace{
+          .timestampNs = 1,
+          .sport = htons(11),  // type
+          .dport = htons(22),  // code
+          .egress = true,
+          .ipProto = IPPROTO_ICMP,
+      },
+      PacketTrace{
+          .timestampNs = 2,
+          .sport = htons(33),  // type
+          .dport = htons(44),  // code
+          .egress = false,
+          .ipProto = IPPROTO_ICMPV6,
+      },
+  };
+
+  std::vector<TracePacket> events;
+  ASSERT_TRUE(TraceAndSortPackets(input, &events));
+
+  ASSERT_EQ(events.size(), 2);
+  EXPECT_FALSE(events[0].network_packet().has_local_port());  // Field must be absent.
+  EXPECT_FALSE(events[0].network_packet().has_remote_port());
+  EXPECT_THAT(events[0].network_packet().icmp_type(), 11);
+  EXPECT_THAT(events[0].network_packet().icmp_code(), 22);
+  EXPECT_THAT(events[0].network_packet().direction(),
+              TrafficDirection::DIR_EGRESS);
+  EXPECT_FALSE(events[1].network_packet().has_local_port());  // Presence, not value.
+  EXPECT_FALSE(events[1].network_packet().has_remote_port());
+  EXPECT_THAT(events[1].network_packet().icmp_type(), 33);
+  EXPECT_THAT(events[1].network_packet().icmp_code(), 44);
+  EXPECT_THAT(events[1].network_packet().direction(),
+              TrafficDirection::DIR_INGRESS);
+}
+
 TEST_F(NetworkTraceHandlerTest, BasicBundling) {
   // TODO: remove this once bundling becomes default. Until then, set arbitrary
   // aggregation threshold to enable bundling.
@@ -168,12 +206,12 @@
   config.set_aggregation_threshold(10);
 
   std::vector<PacketTrace> input = {
-      PacketTrace{.uid = 123, .timestampNs = 2, .length = 200},
-      PacketTrace{.uid = 123, .timestampNs = 1, .length = 100},
-      PacketTrace{.uid = 123, .timestampNs = 4, .length = 300},
+      PacketTrace{.timestampNs = 2, .length = 200, .uid = 123},
+      PacketTrace{.timestampNs = 1, .length = 100, .uid = 123},
+      PacketTrace{.timestampNs = 4, .length = 300, .uid = 123},
 
-      PacketTrace{.uid = 456, .timestampNs = 2, .length = 400},
-      PacketTrace{.uid = 456, .timestampNs = 4, .length = 100},
+      PacketTrace{.timestampNs = 2, .length = 400, .uid = 456},
+      PacketTrace{.timestampNs = 4, .length = 100, .uid = 456},
   };
 
   std::vector<TracePacket> events;
@@ -203,12 +241,12 @@
   config.set_aggregation_threshold(3);
 
   std::vector<PacketTrace> input = {
-      PacketTrace{.uid = 123, .timestampNs = 2, .length = 200},
-      PacketTrace{.uid = 123, .timestampNs = 1, .length = 100},
-      PacketTrace{.uid = 123, .timestampNs = 4, .length = 300},
+      PacketTrace{.timestampNs = 2, .length = 200, .uid = 123},
+      PacketTrace{.timestampNs = 1, .length = 100, .uid = 123},
+      PacketTrace{.timestampNs = 4, .length = 300, .uid = 123},
 
-      PacketTrace{.uid = 456, .timestampNs = 2, .length = 400},
-      PacketTrace{.uid = 456, .timestampNs = 4, .length = 100},
+      PacketTrace{.timestampNs = 2, .length = 400, .uid = 456},
+      PacketTrace{.timestampNs = 4, .length = 100, .uid = 456},
   };
 
   std::vector<TracePacket> events;
@@ -239,12 +277,17 @@
   __be16 b = htons(10001);
   std::vector<PacketTrace> input = {
       // Recall that local is `src` for egress and `dst` for ingress.
-      PacketTrace{.timestampNs = 1, .length = 2, .egress = true, .sport = a},
-      PacketTrace{.timestampNs = 2, .length = 4, .egress = false, .dport = a},
-      PacketTrace{.timestampNs = 3, .length = 6, .egress = true, .sport = b},
-      PacketTrace{.timestampNs = 4, .length = 8, .egress = false, .dport = b},
+      PacketTrace{.timestampNs = 1, .length = 2, .sport = a, .egress = true},
+      PacketTrace{.timestampNs = 2, .length = 4, .dport = a, .egress = false},
+      PacketTrace{.timestampNs = 3, .length = 6, .sport = b, .egress = true},
+      PacketTrace{.timestampNs = 4, .length = 8, .dport = b, .egress = false},
   };
 
+  // Set common fields.
+  for (PacketTrace& pkt : input) {
+    pkt.ipProto = IPPROTO_TCP;
+  }
+
   std::vector<TracePacket> events;
   ASSERT_TRUE(TraceAndSortPackets(input, &events, config));
   ASSERT_EQ(events.size(), 2);
@@ -274,12 +317,17 @@
   __be16 b = htons(80);
   std::vector<PacketTrace> input = {
       // Recall that remote is `dst` for egress and `src` for ingress.
-      PacketTrace{.timestampNs = 1, .length = 2, .egress = true, .dport = a},
-      PacketTrace{.timestampNs = 2, .length = 4, .egress = false, .sport = a},
-      PacketTrace{.timestampNs = 3, .length = 6, .egress = true, .dport = b},
-      PacketTrace{.timestampNs = 4, .length = 8, .egress = false, .sport = b},
+      PacketTrace{.timestampNs = 1, .length = 2, .dport = a, .egress = true},
+      PacketTrace{.timestampNs = 2, .length = 4, .sport = a, .egress = false},
+      PacketTrace{.timestampNs = 3, .length = 6, .dport = b, .egress = true},
+      PacketTrace{.timestampNs = 4, .length = 8, .sport = b, .egress = false},
   };
 
+  // Set common fields.
+  for (PacketTrace& pkt : input) {
+    pkt.ipProto = IPPROTO_TCP;
+  }
+
   std::vector<TracePacket> events;
   ASSERT_TRUE(TraceAndSortPackets(input, &events, config));
   ASSERT_EQ(events.size(), 2);
@@ -306,12 +354,17 @@
   config.set_aggregation_threshold(10);
 
   std::vector<PacketTrace> input = {
-      PacketTrace{.timestampNs = 1, .uid = 123, .length = 1, .tcpFlags = 1},
-      PacketTrace{.timestampNs = 2, .uid = 123, .length = 2, .tcpFlags = 2},
-      PacketTrace{.timestampNs = 3, .uid = 456, .length = 3, .tcpFlags = 1},
-      PacketTrace{.timestampNs = 4, .uid = 456, .length = 4, .tcpFlags = 2},
+      PacketTrace{.timestampNs = 1, .length = 1, .uid = 123, .tcpFlags = 1},
+      PacketTrace{.timestampNs = 2, .length = 2, .uid = 123, .tcpFlags = 2},
+      PacketTrace{.timestampNs = 3, .length = 3, .uid = 456, .tcpFlags = 1},
+      PacketTrace{.timestampNs = 4, .length = 4, .uid = 456, .tcpFlags = 2},
   };
 
+  // Set common fields.
+  for (PacketTrace& pkt : input) {
+    pkt.ipProto = IPPROTO_TCP;
+  }
+
   std::vector<TracePacket> events;
   ASSERT_TRUE(TraceAndSortPackets(input, &events, config));
 
diff --git a/service-t/native/libs/libnetworkstats/include/netdbpf/NetworkTraceHandler.h b/service-t/native/libs/libnetworkstats/include/netdbpf/NetworkTraceHandler.h
index bc10e68..6bf186a 100644
--- a/service-t/native/libs/libnetworkstats/include/netdbpf/NetworkTraceHandler.h
+++ b/service-t/native/libs/libnetworkstats/include/netdbpf/NetworkTraceHandler.h
@@ -30,15 +30,33 @@
 namespace android {
 namespace bpf {
 
-// BundleKeys are PacketTraces where timestamp and length are ignored.
-using BundleKey = PacketTrace;
+// BundleKey encodes a PacketTrace minus timestamp and length. The key should
+// match many packets over time for interning. For convenience, sport/dport
+// are parsed here as either local/remote port or icmp type/code.
+struct BundleKey {
+  explicit BundleKey(const PacketTrace& pkt);  // Parses sport/dport by ipProto.
 
-// BundleKeys are hashed using all fields except timestamp/length.
+  uint32_t ifindex;
+  uint32_t uid;
+  uint32_t tag;
+
+  bool egress;
+  uint8_t ipProto;
+  uint8_t ipVersion;
+
+  std::optional<uint8_t> tcpFlags;     // Set only for IPPROTO_TCP.
+  std::optional<uint16_t> localPort;   // TCP/DCCP/UDP/UDPLITE/SCTP; host order.
+  std::optional<uint16_t> remotePort;  // TCP/DCCP/UDP/UDPLITE/SCTP; host order.
+  std::optional<uint8_t> icmpType;     // Set only for ICMP and ICMPv6.
+  std::optional<uint8_t> icmpCode;     // Set only for ICMP and ICMPv6.
+};
+
+// BundleKeys are hashed using a simple hash combine.
 struct BundleHash {
   std::size_t operator()(const BundleKey& a) const;
 };
 
-// BundleKeys are equal if all fields except timestamp/length are equal.
+// BundleKeys are equal if all fields are equal.
 struct BundleEq {
   bool operator()(const BundleKey& a, const BundleKey& b) const;
 };
@@ -84,13 +102,13 @@
              NetworkTraceHandler::TraceContext& ctx);
 
  private:
-  // Convert a PacketTrace into a Perfetto trace packet.
-  void Fill(const PacketTrace& src,
+  // Fills in contextual information from a bundle without interning.
+  void Fill(const BundleKey& src,
             ::perfetto::protos::pbzero::NetworkPacketEvent* event);
 
   // Fills in contextual information either inline or via interning.
   ::perfetto::protos::pbzero::NetworkPacketBundle* FillWithInterning(
-      NetworkTraceState* state, const BundleKey& key,
+      NetworkTraceState* state, const BundleKey& src,
       ::perfetto::protos::pbzero::TracePacket* dst);
 
   static internal::NetworkTracePoller sPoller;
diff --git a/service-t/src/com/android/server/ConnectivityServiceInitializer.java b/service-t/src/com/android/server/ConnectivityServiceInitializer.java
index 003ec8c..1ac2f6e 100644
--- a/service-t/src/com/android/server/ConnectivityServiceInitializer.java
+++ b/service-t/src/com/android/server/ConnectivityServiceInitializer.java
@@ -28,7 +28,6 @@
 import com.android.server.ethernet.EthernetService;
 import com.android.server.ethernet.EthernetServiceImpl;
 import com.android.server.nearby.NearbyService;
-import com.android.server.remoteauth.RemoteAuthService;
 import com.android.server.thread.ThreadNetworkService;
 
 /**
@@ -43,7 +42,6 @@
     private final NsdService mNsdService;
     private final NearbyService mNearbyService;
     private final EthernetServiceImpl mEthernetServiceImpl;
-    private final RemoteAuthService mRemoteAuthService;
     private final ThreadNetworkService mThreadNetworkService;
 
     public ConnectivityServiceInitializer(Context context) {
@@ -56,7 +54,6 @@
         mConnectivityNative = createConnectivityNativeService(context);
         mNsdService = createNsdService(context);
         mNearbyService = createNearbyService(context);
-        mRemoteAuthService = createRemoteAuthService(context);
         mThreadNetworkService = createThreadNetworkService(context);
     }
 
@@ -94,12 +91,6 @@
                     /* allowIsolated= */ false);
         }
 
-        if (mRemoteAuthService != null) {
-            Log.i(TAG, "Registering " + RemoteAuthService.SERVICE_NAME);
-            publishBinderService(RemoteAuthService.SERVICE_NAME, mRemoteAuthService,
-                    /* allowIsolated= */ false);
-        }
-
         if (mThreadNetworkService != null) {
             Log.i(TAG, "Registering " + ThreadNetworkManager.SERVICE_NAME);
             publishBinderService(ThreadNetworkManager.SERVICE_NAME, mThreadNetworkService,
@@ -164,19 +155,6 @@
         }
     }
 
-    /** Return RemoteAuth service instance */
-    private RemoteAuthService createRemoteAuthService(final Context context) {
-        if (!SdkLevel.isAtLeastV()) return null;
-        try {
-            return new RemoteAuthService(context);
-        } catch (UnsupportedOperationException e) {
-            // RemoteAuth is not yet supported in all branches
-            // TODO: remove catch clause when it is available.
-            Log.i(TAG, "Skipping unsupported service " + RemoteAuthService.SERVICE_NAME);
-            return null;
-        }
-    }
-
     /**
      * Return EthernetServiceImpl instance or null if current SDK is lower than T or Ethernet
      * service isn't necessary.
diff --git a/service/src/com/android/server/ConnectivityService.java b/service/src/com/android/server/ConnectivityService.java
index 85507f6..3ae3e2d 100755
--- a/service/src/com/android/server/ConnectivityService.java
+++ b/service/src/com/android/server/ConnectivityService.java
@@ -97,14 +97,12 @@
 import static android.system.OsConstants.ETH_P_ALL;
 import static android.system.OsConstants.IPPROTO_TCP;
 import static android.system.OsConstants.IPPROTO_UDP;
-
 import static com.android.net.module.util.NetworkMonitorUtils.isPrivateDnsValidationRequired;
 import static com.android.net.module.util.PermissionUtils.checkAnyPermissionOf;
 import static com.android.net.module.util.PermissionUtils.enforceAnyPermissionOf;
 import static com.android.net.module.util.PermissionUtils.enforceNetworkStackPermission;
 import static com.android.net.module.util.PermissionUtils.enforceNetworkStackPermissionOr;
 import static com.android.server.ConnectivityStatsLog.CONNECTIVITY_STATE_SAMPLE;
-
 import static java.util.Map.Entry;
 
 import android.Manifest;
@@ -10614,6 +10612,16 @@
                 err.getFileDescriptor(), args);
     }
 
+    /**
+     * Strictly parses a "true"/"false" shell argument.
+     * @return the parsed value, or null if arg is null or any other string.
+     */
+    private Boolean parseBooleanArgument(final String arg) {
+        if ("true".equals(arg)) return Boolean.TRUE;
+        if ("false".equals(arg)) return Boolean.FALSE;
+        return null;
+    }
+
     private class ShellCmd extends BasicShellCommandHandler {
         @Override
         public int onCommand(String cmd) {
@@ -10643,6 +10651,54 @@
                             onHelp();
                             return -1;
                         }
+                    case "set-chain3-enabled": {
+                        final Boolean enabled = parseBooleanArgument(getNextArg());
+                        if (null == enabled) {
+                            onHelp();
+                            return -1;
+                        }
+                        Log.i(TAG, (enabled ? "En" : "Dis") + "abled FIREWALL_CHAIN_OEM_DENY_3");
+                        setFirewallChainEnabled(ConnectivityManager.FIREWALL_CHAIN_OEM_DENY_3,
+                                enabled);
+                        return 0;
+                    }
+                    case "get-chain3-enabled": {
+                        final boolean chainEnabled = getFirewallChainEnabled(
+                                ConnectivityManager.FIREWALL_CHAIN_OEM_DENY_3);
+                        pw.println("chain:" + (chainEnabled ? "enabled" : "disabled"));
+                        return 0;
+                    }
+                    case "set-package-networking-enabled": {
+                        final Boolean enabled = parseBooleanArgument(getNextArg());
+                        final String packageName = getNextArg();
+                        if (null == enabled || null == packageName) {
+                            onHelp();
+                            return -1;
+                        }
+                        // Throws NameNotFound if the package doesn't exist.
+                        final int appId = setPackageFirewallRule(
+                                ConnectivityManager.FIREWALL_CHAIN_OEM_DENY_3,
+                                packageName, enabled ? FIREWALL_RULE_DEFAULT : FIREWALL_RULE_DENY);
+                        final String msg = (enabled ? "Enabled" : "Disabled")
+                                + " networking for " + packageName + ", appId " + appId;
+                        Log.i(TAG, msg);
+                        pw.println(msg);
+                        return 0;
+                    }
+                    case "get-package-networking-enabled": {
+                        final String packageName = getNextArg();
+                        final int rule = getPackageFirewallRule(
+                                ConnectivityManager.FIREWALL_CHAIN_OEM_DENY_3, packageName);
+                        if (FIREWALL_RULE_ALLOW == rule || FIREWALL_RULE_DEFAULT == rule) {
+                            pw.println(packageName + ":" + "allow");
+                        } else if (FIREWALL_RULE_DENY == rule) {
+                            pw.println(packageName + ":" + "deny");
+                        } else {
+                            throw new IllegalStateException("Unknown rule " + rule + " for package "
+                                    + packageName);
+                        }
+                        return 0;
+                    }
                     case "reevaluate":
                         // Usage : adb shell cmd connectivity reevaluate <netId>
                         // If netId is omitted, then reevaluate the default network
@@ -10683,6 +10739,15 @@
             pw.println("    Turn airplane mode on or off.");
             pw.println("  airplane-mode");
             pw.println("    Get airplane mode.");
+            pw.println("  set-chain3-enabled [true|false]");
+            pw.println("    Enable or disable FIREWALL_CHAIN_OEM_DENY_3 for debugging.");
+            pw.println("  get-chain3-enabled");
+            pw.println("    Returns whether FIREWALL_CHAIN_OEM_DENY_3 is enabled.");
+            pw.println("  set-package-networking-enabled [true|false] [package name]");
+            pw.println("    Set the deny bit in FIREWALL_CHAIN_OEM_DENY_3 to package. This has\n"
+                    + "    no effect if the chain is disabled.");
+            pw.println("  get-package-networking-enabled [package name]");
+            pw.println("    Get the deny bit in FIREWALL_CHAIN_OEM_DENY_3 for package.");
         }
     }
 
@@ -11353,7 +11418,7 @@
         public void onInterfaceLinkStateChanged(@NonNull String iface, boolean up) {
             mHandler.post(() -> {
                 for (NetworkAgentInfo nai : mNetworkAgentInfos) {
-                    nai.clatd.interfaceLinkStateChanged(iface, up);
+                    nai.clatd.handleInterfaceLinkStateChanged(iface, up);
                 }
             });
         }
@@ -11362,7 +11427,7 @@
         public void onInterfaceRemoved(@NonNull String iface) {
             mHandler.post(() -> {
                 for (NetworkAgentInfo nai : mNetworkAgentInfos) {
-                    nai.clatd.interfaceRemoved(iface);
+                    nai.clatd.handleInterfaceRemoved(iface);
                 }
             });
         }
@@ -12418,6 +12483,21 @@
         }
     }
 
+    private int setPackageFirewallRule(final int chain, final String packageName, final int rule)
+            throws PackageManager.NameNotFoundException {
+        final PackageManager pm = mContext.getPackageManager();  // Resolves package -> uid.
+        final int appId = UserHandle.getAppId(pm.getPackageUid(packageName, 0 /* flags */));
+        if (appId < Process.FIRST_APPLICATION_UID) {  // Refuse appIds below the app range.
+            throw new RuntimeException("Can't set package firewall rule for system app "
+                    + packageName + " with appId " + appId);
+        }
+        for (final UserHandle uh : mUserManager.getUserHandles(false /* excludeDying */)) {
+            final int uid = uh.getUid(appId);  // This user's uid for the package's appId.
+            setUidFirewallRule(chain, uid, rule);
+        }
+        return appId;  // Returned so the shell command can log and print it.
+    }
+
     @Override
     public void setUidFirewallRule(final int chain, final int uid, final int rule) {
         enforceNetworkStackOrSettingsPermission();
@@ -12436,6 +12516,13 @@
         }
     }
 
+    private int getPackageFirewallRule(final int chain, final String packageName)
+            throws PackageManager.NameNotFoundException {
+        final PackageManager pm = mContext.getPackageManager();
+        final int appId = UserHandle.getAppId(pm.getPackageUid(packageName, 0 /* flags */));
+        // NOTE(review): queries with appId as the uid; the setter writes one uid per user.
+        return getUidFirewallRule(chain, appId);
+    }
+
     @Override
     public int getUidFirewallRule(final int chain, final int uid) {
         enforceNetworkStackOrSettingsPermission();
diff --git a/service/src/com/android/server/connectivity/CarrierPrivilegeAuthenticator.java b/service/src/com/android/server/connectivity/CarrierPrivilegeAuthenticator.java
index 4325763..88aa329 100644
--- a/service/src/com/android/server/connectivity/CarrierPrivilegeAuthenticator.java
+++ b/service/src/com/android/server/connectivity/CarrierPrivilegeAuthenticator.java
@@ -35,6 +35,7 @@
 import android.telephony.SubscriptionManager;
 import android.telephony.TelephonyManager;
 import android.util.Log;
+import android.util.SparseIntArray;
 
 import com.android.internal.annotations.GuardedBy;
 import com.android.internal.annotations.VisibleForTesting;
@@ -63,7 +64,7 @@
     private final TelephonyManagerShim mTelephonyManagerShim;
     private final TelephonyManager mTelephonyManager;
     @GuardedBy("mLock")
-    private int[] mCarrierServiceUid;
+    private final SparseIntArray mCarrierServiceUid = new SparseIntArray(2 /* initialCapacity */);
     @GuardedBy("mLock")
     private int mModemCount = 0;
     private final Object mLock = new Object();
@@ -75,7 +76,7 @@
 
     public CarrierPrivilegeAuthenticator(@NonNull final Context c,
             @NonNull final TelephonyManager t,
-            @NonNull final TelephonyManagerShimImpl telephonyManagerShim) {
+            @NonNull final TelephonyManagerShim telephonyManagerShim) {
         mContext = c;
         mTelephonyManager = t;
         mTelephonyManagerShim = telephonyManagerShim;
@@ -91,17 +92,7 @@
 
     public CarrierPrivilegeAuthenticator(@NonNull final Context c,
             @NonNull final TelephonyManager t) {
-        mContext = c;
-        mTelephonyManager = t;
-        mTelephonyManagerShim = TelephonyManagerShimImpl.newInstance(mTelephonyManager);
-        mThread = new HandlerThread(TAG);
-        mThread.start();
-        mHandler = new Handler(mThread.getLooper()) {};
-        synchronized (mLock) {
-            mModemCount = mTelephonyManager.getActiveModemCount();
-            registerForCarrierChanges();
-            updateCarrierServiceUid();
-        }
+        this(c, t, TelephonyManagerShimImpl.newInstance(t));
     }
 
     /**
@@ -233,9 +224,9 @@
     @VisibleForTesting
     void updateCarrierServiceUid() {
         synchronized (mLock) {
-            mCarrierServiceUid = new int[mModemCount];
+            mCarrierServiceUid.clear();
             for (int i = 0; i < mModemCount; i++) {
-                mCarrierServiceUid[i] = getCarrierServicePackageUidForSlot(i);
+                mCarrierServiceUid.put(i, getCarrierServicePackageUidForSlot(i));
             }
         }
     }
@@ -244,11 +235,8 @@
     int getCarrierServiceUidForSubId(int subId) {
         final int slotId = getSlotIndex(subId);
         synchronized (mLock) {
-            if (slotId != SubscriptionManager.INVALID_SIM_SLOT_INDEX && slotId < mModemCount) {
-                return mCarrierServiceUid[slotId];
-            }
+            return mCarrierServiceUid.get(slotId, Process.INVALID_UID);
         }
-        return Process.INVALID_UID;
     }
 
     @VisibleForTesting
diff --git a/service/src/com/android/server/connectivity/Nat464Xlat.java b/service/src/com/android/server/connectivity/Nat464Xlat.java
index f9e07fd..065922d 100644
--- a/service/src/com/android/server/connectivity/Nat464Xlat.java
+++ b/service/src/com/android/server/connectivity/Nat464Xlat.java
@@ -483,8 +483,9 @@
 
     /**
      * Adds stacked link on base link and transitions to RUNNING state.
+     * Must be called on the handler thread.
      */
-    private void handleInterfaceLinkStateChanged(String iface, boolean up) {
+    public void handleInterfaceLinkStateChanged(String iface, boolean up) {
         // TODO: if we call start(), then stop(), then start() again, and the
         // interfaceLinkStateChanged notification for the first start is delayed past the first
         // stop, then the code becomes out of sync with system state and will behave incorrectly.
@@ -499,6 +500,7 @@
         // Once this code is converted to StateMachine, it will be possible to use deferMessage to
         // ensure it stays in STARTING state until the interfaceLinkStateChanged notification fires,
         // and possibly use a timeout (or provide some guarantees at the lower layer) to address #1.
+        ensureRunningOnHandlerThread();
         if (!isStarting() || !up || !Objects.equals(mIface, iface)) {
             return;
         }
@@ -519,8 +521,10 @@
 
     /**
      * Removes stacked link on base link and transitions to IDLE state.
+     * Must be called on the handler thread.
      */
-    private void handleInterfaceRemoved(String iface) {
+    public void handleInterfaceRemoved(String iface) {
+        ensureRunningOnHandlerThread();
         if (!Objects.equals(mIface, iface)) {
             return;
         }
@@ -536,14 +540,6 @@
         stop();
     }
 
-    public void interfaceLinkStateChanged(String iface, boolean up) {
-        mNetwork.handler().post(() -> { handleInterfaceLinkStateChanged(iface, up); });
-    }
-
-    public void interfaceRemoved(String iface) {
-        mNetwork.handler().post(() -> handleInterfaceRemoved(iface));
-    }
-
     /**
      * Translate the input v4 address to v6 clat address.
      */
diff --git a/service/src/com/android/server/connectivity/NetworkDiagnostics.java b/service/src/com/android/server/connectivity/NetworkDiagnostics.java
index e1e2585..3db37e5 100644
--- a/service/src/com/android/server/connectivity/NetworkDiagnostics.java
+++ b/service/src/com/android/server/connectivity/NetworkDiagnostics.java
@@ -340,8 +340,9 @@
     @TargetApi(Build.VERSION_CODES.S)
     private int getMtuForTarget(InetAddress target) {
         final int family = target instanceof Inet4Address ? AF_INET : AF_INET6;
+        FileDescriptor socket = null;
         try {
-            final FileDescriptor socket = Os.socket(family, SOCK_DGRAM, 0);
+            socket = Os.socket(family, SOCK_DGRAM, 0);
             mNetwork.bindSocket(socket);
             Os.connect(socket, target, 0);
             if (family == AF_INET) {
@@ -352,6 +353,8 @@
         } catch (ErrnoException | IOException e) {
             Log.e(TAG, "Can't get MTU for destination " + target, e);
             return -1;
+        } finally {
+            IoUtils.closeQuietly(socket);
         }
     }
 
diff --git a/staticlibs/native/bpf_headers/include/bpf/BpfMap.h b/staticlibs/native/bpf_headers/include/bpf/BpfMap.h
index 847083e..3be7067 100644
--- a/staticlibs/native/bpf_headers/include/bpf/BpfMap.h
+++ b/staticlibs/native/bpf_headers/include/bpf/BpfMap.h
@@ -18,10 +18,10 @@
 
 #include <linux/bpf.h>
 
+#include <android/log.h>
 #include <android-base/result.h>
 #include <android-base/stringprintf.h>
 #include <android-base/unique_fd.h>
-#include <utils/Log.h>
 
 #include "BpfSyscallWrappers.h"
 #include "bpf/BpfUtils.h"
diff --git a/staticlibs/testutils/host/com/android/testutils/ConnectivityTestTargetPreparer.kt b/staticlibs/testutils/host/com/android/testutils/ConnectivityTestTargetPreparer.kt
index 3fc74aa..eb94781 100644
--- a/staticlibs/testutils/host/com/android/testutils/ConnectivityTestTargetPreparer.kt
+++ b/staticlibs/testutils/host/com/android/testutils/ConnectivityTestTargetPreparer.kt
@@ -32,6 +32,10 @@
 private const val CONNECTIVITY_CHECK_RUNNER_NAME = "androidx.test.runner.AndroidJUnitRunner"
 private const val IGNORE_CONN_CHECK_OPTION = "ignore-connectivity-check"
 
+// The default updater package names. These updaters might be updating packages while the
+// CTS tests are running.
+private val UPDATER_PKGS = arrayOf("com.google.android.gms", "com.android.vending")
+
 /**
  * A target preparer that sets up and verifies a device for connectivity tests.
  *
@@ -45,35 +49,42 @@
     @Option(name = IGNORE_CONN_CHECK_OPTION,
             description = "Disables the check for mobile data and wifi")
     private var ignoreConnectivityCheck = false
+    // The default value is never used, but false is a reasonable default
+    private var originalTestChainEnabled = false
+    private val originalUpdaterPkgsStatus = HashMap<String, Boolean>()
 
-    override fun setUp(testInformation: TestInformation) {
+    override fun setUp(testInfo: TestInformation) {
         if (isDisabled) return
-        disableGmsUpdate(testInformation)
-        runPreparerApk(testInformation)
+        disableGmsUpdate(testInfo)
+        originalTestChainEnabled = getTestChainEnabled(testInfo)
+        originalUpdaterPkgsStatus.putAll(getUpdaterPkgsStatus(testInfo))
+        setUpdaterNetworkingEnabled(testInfo, enableChain = true,
+                enablePkgs = UPDATER_PKGS.associateWith { false })
+        runPreparerApk(testInfo)
     }
 
-    private fun runPreparerApk(testInformation: TestInformation) {
+    private fun runPreparerApk(testInfo: TestInformation) {
         installer.setCleanApk(true)
         installer.addTestFileName(CONNECTIVITY_CHECKER_APK)
         installer.setShouldGrantPermission(true)
-        installer.setUp(testInformation)
+        installer.setUp(testInfo)
 
         val runner = DefaultRemoteAndroidTestRunner(
                 CONNECTIVITY_PKG_NAME,
                 CONNECTIVITY_CHECK_RUNNER_NAME,
-                testInformation.device.iDevice)
+                testInfo.device.iDevice)
         runner.runOptions = "--no-hidden-api-checks"
 
         val receiver = CollectingTestListener()
-        if (!testInformation.device.runInstrumentationTests(runner, receiver)) {
+        if (!testInfo.device.runInstrumentationTests(runner, receiver)) {
             throw TargetSetupError("Device state check failed to complete",
-                    testInformation.device.deviceDescriptor)
+                    testInfo.device.deviceDescriptor)
         }
 
         val runResult = receiver.currentRunResults
         if (runResult.isRunFailure) {
             throw TargetSetupError("Failed to check device state before the test: " +
-                    runResult.runFailureMessage, testInformation.device.deviceDescriptor)
+                    runResult.runFailureMessage, testInfo.device.deviceDescriptor)
         }
 
         val ignoredTestClasses = mutableSetOf<String>()
@@ -92,25 +103,50 @@
         if (errorMsg.isBlank()) return
 
         throw TargetSetupError("Device setup checks failed. Check the test bench: \n$errorMsg",
-                testInformation.device.deviceDescriptor)
+                testInfo.device.deviceDescriptor)
     }
 
-    private fun disableGmsUpdate(testInformation: TestInformation) {
+    private fun disableGmsUpdate(testInfo: TestInformation) {
         // This will be a no-op on devices without root (su) or not using gservices, but that's OK.
-        testInformation.device.executeShellCommand("su 0 am broadcast " +
+        testInfo.exec("su 0 am broadcast " +
                 "-a com.google.gservices.intent.action.GSERVICES_OVERRIDE " +
                 "-e finsky.play_services_auto_update_enabled false")
     }
 
-    private fun clearGmsUpdateOverride(testInformation: TestInformation) {
-        testInformation.device.executeShellCommand("su 0 am broadcast " +
+    private fun clearGmsUpdateOverride(testInfo: TestInformation) {
+        testInfo.exec("su 0 am broadcast " +
                 "-a com.google.gservices.intent.action.GSERVICES_OVERRIDE " +
                 "--esn finsky.play_services_auto_update_enabled")
     }
 
-    override fun tearDown(testInformation: TestInformation, e: Throwable?) {
+    private fun setUpdaterNetworkingEnabled(
+            testInfo: TestInformation,
+            enableChain: Boolean,
+            enablePkgs: Map<String, Boolean>
+    ) {
+        // Do nothing below API level 31 (Build.VERSION_CODES.S), where this is not available.
+        if (testInfo.device.getApiLevel() < 31) return
+        testInfo.exec("cmd connectivity set-chain3-enabled $enableChain")
+        enablePkgs.forEach { (pkg, allow) ->
+            testInfo.exec("cmd connectivity set-package-networking-enabled $pkg $allow")
+        }
+    }
+
+    private fun getTestChainEnabled(testInfo: TestInformation) =
+            testInfo.exec("cmd connectivity get-chain3-enabled").contains("chain:enabled")
+
+    private fun getUpdaterPkgsStatus(testInfo: TestInformation) =
+            UPDATER_PKGS.associateWith { pkg ->
+                !testInfo.exec("cmd connectivity get-package-networking-enabled $pkg")
+                        .contains(":deny")
+            }
+
+    override fun tearDown(testInfo: TestInformation, e: Throwable?) {
         if (isTearDownDisabled) return
-        installer.tearDown(testInformation, e)
-        clearGmsUpdateOverride(testInformation)
+        installer.tearDown(testInfo, e)
+        setUpdaterNetworkingEnabled(testInfo,
+                enableChain = originalTestChainEnabled,
+                enablePkgs = originalUpdaterPkgsStatus)
+        clearGmsUpdateOverride(testInfo)
     }
 }
diff --git a/staticlibs/testutils/host/com/android/testutils/DisableConfigSyncTargetPreparer.kt b/staticlibs/testutils/host/com/android/testutils/DisableConfigSyncTargetPreparer.kt
index 63f05a6..bc00f3c 100644
--- a/staticlibs/testutils/host/com/android/testutils/DisableConfigSyncTargetPreparer.kt
+++ b/staticlibs/testutils/host/com/android/testutils/DisableConfigSyncTargetPreparer.kt
@@ -58,4 +58,4 @@
     }
 }
 
-private fun TestInformation.exec(cmd: String) = this.device.executeShellCommand(cmd)
\ No newline at end of file
+fun TestInformation.exec(cmd: String) = this.device.executeShellCommand(cmd)
diff --git a/tests/cts/net/src/android/net/cts/ConnectivityManagerTest.java b/tests/cts/net/src/android/net/cts/ConnectivityManagerTest.java
index 59aefa5..d2c9481 100644
--- a/tests/cts/net/src/android/net/cts/ConnectivityManagerTest.java
+++ b/tests/cts/net/src/android/net/cts/ConnectivityManagerTest.java
@@ -2723,7 +2723,8 @@
             // the network with the TEST transport. Also wait for validation here, in case there
             // is a bug that's only visible when the network is validated.
             setWifiMeteredStatusAndWait(ssid, true /* isMetered */, true /* waitForValidation */);
-            defaultCallback.expect(CallbackEntry.LOST, wifiNetwork, NETWORK_CALLBACK_TIMEOUT_MS);
+            defaultCallback.eventuallyExpect(CallbackEntry.LOST, NETWORK_CALLBACK_TIMEOUT_MS,
+                    l -> l.getNetwork().equals(wifiNetwork));
             waitForAvailable(defaultCallback, tnt.getNetwork());
             // Depending on if this device has cellular connectivity or not, multiple available
             // callbacks may be received. Eventually, metered Wi-Fi should be the final available
diff --git a/tests/unit/AndroidManifest.xml b/tests/unit/AndroidManifest.xml
index 5d4bdf7..2853f31 100644
--- a/tests/unit/AndroidManifest.xml
+++ b/tests/unit/AndroidManifest.xml
@@ -49,6 +49,9 @@
     <uses-permission android:name="android.permission.NETWORK_FACTORY" />
     <uses-permission android:name="android.permission.NETWORK_STATS_PROVIDER" />
     <uses-permission android:name="android.permission.CONTROL_OEM_PAID_NETWORK_PREFERENCE" />
+    <!-- Workaround for flakes where the launcher package is not found despite the <queries> tag
+         below (b/286550950). -->
+    <uses-permission android:name="android.permission.QUERY_ALL_PACKAGES" />
 
     <!-- Declare the intent that the test intends to query. This is necessary for
          UiDevice.getLauncherPackageName which is used in NetworkNotificationManagerTest
diff --git a/tests/unit/java/com/android/server/ConnectivityServiceTest.java b/tests/unit/java/com/android/server/ConnectivityServiceTest.java
index 2fccdcb..c8cbce1 100755
--- a/tests/unit/java/com/android/server/ConnectivityServiceTest.java
+++ b/tests/unit/java/com/android/server/ConnectivityServiceTest.java
@@ -10772,6 +10772,8 @@
         final RouteInfo ipv4Subnet = new RouteInfo(myIpv4, null, MOBILE_IFNAME);
         final RouteInfo stackedDefault =
                 new RouteInfo((IpPrefix) null, myIpv4.getAddress(), CLAT_MOBILE_IFNAME);
+        final BaseNetdUnsolicitedEventListener netdUnsolicitedListener =
+                getRegisteredNetdUnsolicitedEventListener();
 
         final NetworkRequest networkRequest = new NetworkRequest.Builder()
                 .addTransportType(TRANSPORT_CELLULAR)
@@ -10839,7 +10841,6 @@
         assertRoutesRemoved(cellNetId, ipv4Subnet);
 
         // When NAT64 prefix discovery succeeds, LinkProperties are updated and clatd is started.
-        Nat464Xlat clat = getNat464Xlat(mCellAgent);
         assertNull(mCm.getLinkProperties(mCellAgent.getNetwork()).getNat64Prefix());
         mService.mResolverUnsolEventCallback.onNat64PrefixEvent(
                 makeNat64PrefixEvent(cellNetId, PREFIX_OPERATION_ADDED, kNat64PrefixString, 96));
@@ -10850,7 +10851,8 @@
         verifyClatdStart(null /* inOrder */, MOBILE_IFNAME, cellNetId, kNat64Prefix.toString());
 
         // Clat iface comes up. Expect stacked link to be added.
-        clat.interfaceLinkStateChanged(CLAT_MOBILE_IFNAME, true);
+        netdUnsolicitedListener.onInterfaceLinkStateChanged(
+                CLAT_MOBILE_IFNAME, true);
         networkCallback.expect(LINK_PROPERTIES_CHANGED, mCellAgent);
         List<LinkProperties> stackedLps = mCm.getLinkProperties(mCellAgent.getNetwork())
                 .getStackedLinks();
@@ -10896,7 +10898,7 @@
                 kOtherNat64Prefix.toString());
         networkCallback.expect(LINK_PROPERTIES_CHANGED, mCellAgent,
                 cb -> cb.getLp().getNat64Prefix().equals(kOtherNat64Prefix));
-        clat.interfaceLinkStateChanged(CLAT_MOBILE_IFNAME, true);
+        netdUnsolicitedListener.onInterfaceLinkStateChanged(CLAT_MOBILE_IFNAME, true);
         networkCallback.expect(LINK_PROPERTIES_CHANGED, mCellAgent,
                 cb -> cb.getLp().getStackedLinks().size() == 1);
         assertRoutesAdded(cellNetId, stackedDefault);
@@ -10924,7 +10926,7 @@
         assertRoutesRemoved(cellNetId, stackedDefault);
 
         // The interface removed callback happens but has no effect after stop is called.
-        clat.interfaceRemoved(CLAT_MOBILE_IFNAME);
+        netdUnsolicitedListener.onInterfaceRemoved(CLAT_MOBILE_IFNAME);
         networkCallback.assertNoCallback();
         verify(mMockNetd, times(1)).networkRemoveInterface(cellNetId, CLAT_MOBILE_IFNAME);
 
@@ -10961,7 +10963,7 @@
         verifyClatdStart(null /* inOrder */, MOBILE_IFNAME, cellNetId, kNat64Prefix.toString());
 
         // Clat iface comes up. Expect stacked link to be added.
-        clat.interfaceLinkStateChanged(CLAT_MOBILE_IFNAME, true);
+        netdUnsolicitedListener.onInterfaceLinkStateChanged(CLAT_MOBILE_IFNAME, true);
         networkCallback.expect(LINK_PROPERTIES_CHANGED, mCellAgent,
                 cb -> cb.getLp().getStackedLinks().size() == 1
                         && cb.getLp().getNat64Prefix() != null);
@@ -11029,8 +11031,7 @@
 
         // Clatd is started and clat iface comes up. Expect stacked link to be added.
         verifyClatdStart(null /* inOrder */, MOBILE_IFNAME, cellNetId, kNat64Prefix.toString());
-        clat = getNat464Xlat(mCellAgent);
-        clat.interfaceLinkStateChanged(CLAT_MOBILE_IFNAME, true /* up */);
+        netdUnsolicitedListener.onInterfaceLinkStateChanged(CLAT_MOBILE_IFNAME, true /* up */);
         networkCallback.expect(LINK_PROPERTIES_CHANGED, mCellAgent,
                 cb -> cb.getLp().getStackedLinks().size() == 1
                         && cb.getLp().getNat64Prefix().equals(kNat64Prefix));
diff --git a/tests/unit/java/com/android/server/connectivity/Nat464XlatTest.java b/tests/unit/java/com/android/server/connectivity/Nat464XlatTest.java
index 58c0114..2fe8713 100644
--- a/tests/unit/java/com/android/server/connectivity/Nat464XlatTest.java
+++ b/tests/unit/java/com/android/server/connectivity/Nat464XlatTest.java
@@ -86,7 +86,6 @@
     @Mock ClatCoordinator mClatCoordinator;
 
     TestLooper mLooper;
-    Handler mHandler;
     NetworkAgentConfig mAgentConfig = new NetworkAgentConfig();
 
     Nat464Xlat makeNat464Xlat(boolean isCellular464XlatEnabled) {
@@ -96,6 +95,14 @@
             }
         };
 
+        // The test looper needs to be created here on the test case thread and not in setUp,
+        // because setUp and test cases are run in different threads. Creating the test looper in
+        // setUp would make Looper.getThread() return the setUp thread, which does not match the
+        // test case thread that is actually used to process the messages.
+        mLooper = new TestLooper();
+        final Handler handler = new Handler(mLooper.getLooper());
+        doReturn(handler).when(mNai).handler();
+
         return new Nat464Xlat(mNai, mNetd, mDnsResolver, deps) {
             @Override protected int getNetId() {
                 return NETID;
@@ -117,9 +124,6 @@
 
     @Before
     public void setUp() throws Exception {
-        mLooper = new TestLooper();
-        mHandler = new Handler(mLooper.getLooper());
-
         MockitoAnnotations.initMocks(this);
 
         mNai.linkProperties = new LinkProperties();
@@ -130,7 +134,6 @@
         markNetworkConnected();
         when(mNai.connService()).thenReturn(mConnectivity);
         when(mNai.netAgentConfig()).thenReturn(mAgentConfig);
-        when(mNai.handler()).thenReturn(mHandler);
         final InterfaceConfigurationParcel mConfig = new InterfaceConfigurationParcel();
         when(mNetd.interfaceGetCfg(eq(STACKED_IFACE))).thenReturn(mConfig);
         mConfig.ipv4Addr = ADDR.getAddress().getHostAddress();
@@ -272,8 +275,7 @@
         verifyClatdStart(null /* inOrder */);
 
         // Stacked interface up notification arrives.
-        nat.interfaceLinkStateChanged(STACKED_IFACE, true);
-        mLooper.dispatchNext();
+        nat.handleInterfaceLinkStateChanged(STACKED_IFACE, true);
 
         verify(mNetd).interfaceGetCfg(eq(STACKED_IFACE));
         verify(mConnectivity).handleUpdateLinkProperties(eq(mNai), c.capture());
@@ -294,8 +296,7 @@
         // Verify the generated v6 is reset when clat is stopped.
         assertNull(nat.mIPv6Address);
         // Stacked interface removed notification arrives and is ignored.
-        nat.interfaceRemoved(STACKED_IFACE);
-        mLooper.dispatchNext();
+        nat.handleInterfaceRemoved(STACKED_IFACE);
 
         verifyNoMoreInteractions(mNetd, mConnectivity);
     }
@@ -324,8 +325,7 @@
         verifyClatdStart(inOrder);
 
         // Stacked interface up notification arrives.
-        nat.interfaceLinkStateChanged(STACKED_IFACE, true);
-        mLooper.dispatchNext();
+        nat.handleInterfaceLinkStateChanged(STACKED_IFACE, true);
 
         inOrder.verify(mConnectivity).handleUpdateLinkProperties(eq(mNai), c.capture());
         assertFalse(c.getValue().getStackedLinks().isEmpty());
@@ -344,10 +344,8 @@
 
         if (interfaceRemovedFirst) {
             // Stacked interface removed notification arrives and is ignored.
-            nat.interfaceRemoved(STACKED_IFACE);
-            mLooper.dispatchNext();
-            nat.interfaceLinkStateChanged(STACKED_IFACE, false);
-            mLooper.dispatchNext();
+            nat.handleInterfaceRemoved(STACKED_IFACE);
+            nat.handleInterfaceLinkStateChanged(STACKED_IFACE, false);
         }
 
         assertTrue(c.getValue().getStackedLinks().isEmpty());
@@ -361,15 +359,12 @@
 
         if (!interfaceRemovedFirst) {
             // Stacked interface removed notification arrives and is ignored.
-            nat.interfaceRemoved(STACKED_IFACE);
-            mLooper.dispatchNext();
-            nat.interfaceLinkStateChanged(STACKED_IFACE, false);
-            mLooper.dispatchNext();
+            nat.handleInterfaceRemoved(STACKED_IFACE);
+            nat.handleInterfaceLinkStateChanged(STACKED_IFACE, false);
         }
 
         // Stacked interface up notification arrives.
-        nat.interfaceLinkStateChanged(STACKED_IFACE, true);
-        mLooper.dispatchNext();
+        nat.handleInterfaceLinkStateChanged(STACKED_IFACE, true);
 
         inOrder.verify(mConnectivity).handleUpdateLinkProperties(eq(mNai), c.capture());
         assertFalse(c.getValue().getStackedLinks().isEmpty());
@@ -411,8 +406,7 @@
         verifyClatdStart(null /* inOrder */);
 
         // Stacked interface up notification arrives.
-        nat.interfaceLinkStateChanged(STACKED_IFACE, true);
-        mLooper.dispatchNext();
+        nat.handleInterfaceLinkStateChanged(STACKED_IFACE, true);
 
         verify(mNetd).interfaceGetCfg(eq(STACKED_IFACE));
         verify(mConnectivity, times(1)).handleUpdateLinkProperties(eq(mNai), c.capture());
@@ -421,8 +415,7 @@
         assertRunning(nat);
 
         // Stacked interface removed notification arrives (clatd crashed, ...).
-        nat.interfaceRemoved(STACKED_IFACE);
-        mLooper.dispatchNext();
+        nat.handleInterfaceRemoved(STACKED_IFACE);
 
         verifyClatdStop(null /* inOrder */);
         verify(mConnectivity, times(2)).handleUpdateLinkProperties(eq(mNai), c.capture());
@@ -457,12 +450,10 @@
         assertIdle(nat);
 
         // In-flight interface up notification arrives: no-op
-        nat.interfaceLinkStateChanged(STACKED_IFACE, true);
-        mLooper.dispatchNext();
+        nat.handleInterfaceLinkStateChanged(STACKED_IFACE, true);
 
         // Interface removed notification arrives after stopClatd() takes effect: no-op.
-        nat.interfaceRemoved(STACKED_IFACE);
-        mLooper.dispatchNext();
+        nat.handleInterfaceRemoved(STACKED_IFACE);
 
         assertIdle(nat);