Snap for 12327203 from a6648cbbf2a3d9aa94c9fd2bddec33e1e0dd774e to 24Q4-release

Change-Id: I55bdf605ac3c0e02f0c9f2b16003dd9a9aaa4533
diff --git a/Tethering/tests/integration/base/android/net/EthernetTetheringTestBase.java b/Tethering/tests/integration/base/android/net/EthernetTetheringTestBase.java
index 1eb6255..d5c6a8e 100644
--- a/Tethering/tests/integration/base/android/net/EthernetTetheringTestBase.java
+++ b/Tethering/tests/integration/base/android/net/EthernetTetheringTestBase.java
@@ -568,6 +568,12 @@
         return nif.getMTU();
     }
 
+    protected int getIndexByName(String ifaceName) throws SocketException {
+        NetworkInterface nif = NetworkInterface.getByName(ifaceName);
+        assertNotNull("Can't get NetworkInterface object for " + ifaceName, nif);
+        return nif.getIndex();
+    }
+
     protected TapPacketReader makePacketReader(final TestNetworkInterface iface) throws Exception {
         FileDescriptor fd = iface.getFileDescriptor().getFileDescriptor();
         return makePacketReader(fd, getMTU(iface));
@@ -968,6 +974,11 @@
         return Struct.parse(Ipv6Header.class, ByteBuffer.wrap(expectedPacket)).srcIp;
     }
 
+    protected String getUpstreamInterfaceName() {
+        if (mUpstreamReader == null || mUpstreamTracker == null) return null;  // avoid NPE below
+        return mUpstreamTracker.getTestIface().getInterfaceName();
+    }
+
     protected <T> List<T> toList(T... array) {
         return Arrays.asList(array);
     }
diff --git a/Tethering/tests/integration/src/android/net/EthernetTetheringTest.java b/Tethering/tests/integration/src/android/net/EthernetTetheringTest.java
index 049f5f0..32b2f3e 100644
--- a/Tethering/tests/integration/src/android/net/EthernetTetheringTest.java
+++ b/Tethering/tests/integration/src/android/net/EthernetTetheringTest.java
@@ -1066,24 +1066,34 @@
         runUdp4Test();
     }
 
-    private ClatEgress4Value getClatEgress4Value() throws Exception {
+    private ClatEgress4Value getClatEgress4Value(int clatIfaceIndex) throws Exception {
         // Command: dumpsys connectivity clatEgress4RawBpfMap
         final String[] args = new String[] {DUMPSYS_CLAT_RAWMAP_EGRESS4_ARG};
         final HashMap<ClatEgress4Key, ClatEgress4Value> egress4Map = pollRawMapFromDump(
                 ClatEgress4Key.class, ClatEgress4Value.class, Context.CONNECTIVITY_SERVICE, args);
         assertNotNull(egress4Map);
-        assertEquals(1, egress4Map.size());
-        return egress4Map.entrySet().iterator().next().getValue();
+        for (Map.Entry<ClatEgress4Key, ClatEgress4Value> entry : egress4Map.entrySet()) {
+            ClatEgress4Key key = entry.getKey();
+            if (key.iif == clatIfaceIndex) {
+                return entry.getValue();
+            }
+        }
+        return null;
     }
 
-    private ClatIngress6Value getClatIngress6Value() throws Exception {
+    private ClatIngress6Value getClatIngress6Value(int ifaceIndex) throws Exception {
         // Command: dumpsys connectivity clatIngress6RawBpfMap
         final String[] args = new String[] {DUMPSYS_CLAT_RAWMAP_INGRESS6_ARG};
         final HashMap<ClatIngress6Key, ClatIngress6Value> ingress6Map = pollRawMapFromDump(
                 ClatIngress6Key.class, ClatIngress6Value.class, Context.CONNECTIVITY_SERVICE, args);
         assertNotNull(ingress6Map);
-        assertEquals(1, ingress6Map.size());
-        return ingress6Map.entrySet().iterator().next().getValue();
+        for (Map.Entry<ClatIngress6Key, ClatIngress6Value> entry : ingress6Map.entrySet()) {
+            ClatIngress6Key key = entry.getKey();
+            if (key.iif == ifaceIndex) {
+                return entry.getValue();
+            }
+        }
+        return null;
     }
 
     /**
@@ -1115,8 +1125,13 @@
         final Inet6Address clatIp6 = getClatIpv6Address(tester, tethered);
 
         // Get current values before sending packets.
-        final ClatEgress4Value oldEgress4 = getClatEgress4Value();
-        final ClatIngress6Value oldIngress6 = getClatIngress6Value();
+        final String ifaceName = getUpstreamInterfaceName();
+        final int ifaceIndex = getIndexByName(ifaceName);
+        final int clatIfaceIndex = getIndexByName("v4-" + ifaceName);
+        final ClatEgress4Value oldEgress4 = getClatEgress4Value(clatIfaceIndex);
+        final ClatIngress6Value oldIngress6 = getClatIngress6Value(ifaceIndex);
+        assertNotNull(oldEgress4);
+        assertNotNull(oldIngress6);
 
         // Send an IPv4 UDP packet in original direction.
         // IPv4 packet -- CLAT translation --> IPv6 packet
@@ -1145,8 +1160,10 @@
                 ByteBuffer.wrap(payload), l2mtu);
 
         // After sending test packets, get stats again to verify their differences.
-        final ClatEgress4Value newEgress4 = getClatEgress4Value();
-        final ClatIngress6Value newIngress6 = getClatIngress6Value();
+        final ClatEgress4Value newEgress4 = getClatEgress4Value(clatIfaceIndex);
+        final ClatIngress6Value newIngress6 = getClatIngress6Value(ifaceIndex);
+        assertNotNull(newEgress4);
+        assertNotNull(newIngress6);
 
         assertEquals(RX_UDP_PACKET_COUNT + fragPktCnt, newIngress6.packets - oldIngress6.packets);
         assertEquals(RX_UDP_PACKET_COUNT * RX_UDP_PACKET_SIZE + fragRxBytes,
diff --git a/bpf/headers/include/bpf_helpers.h b/bpf/headers/include/bpf_helpers.h
index 1a9fd31..ac5ffda 100644
--- a/bpf/headers/include/bpf_helpers.h
+++ b/bpf/headers/include/bpf_helpers.h
@@ -291,6 +291,12 @@
         bpf_ringbuf_submit_unsafe(v, 0);                                       \
     }
 
+#define DEFINE_BPF_RINGBUF(the_map, ValueType, size_bytes, usr, grp, md)                \
+    DEFINE_BPF_RINGBUF_EXT(the_map, ValueType, size_bytes, usr, grp, md,                \
+                           DEFAULT_BPF_MAP_SELINUX_CONTEXT, DEFAULT_BPF_MAP_PIN_SUBDIR, \
+                           PRIVATE, BPFLOADER_MIN_VER, BPFLOADER_MAX_VER,               \
+                           LOAD_ON_ENG, LOAD_ON_USER, LOAD_ON_USERDEBUG)
+
 /* There exist buggy kernels with pre-T OS, that due to
  * kernel patch "[ALPS05162612] bpf: fix ubsan error"
  * do not support userspace writes into non-zero index of bpf map arrays.
@@ -349,11 +355,17 @@
 #error "Bpf Map UID must be left at default of AID_ROOT for BpfLoader prior to v0.28"
 #endif
 
-#define DEFINE_BPF_MAP_UGM(the_map, TYPE, KeyType, ValueType, num_entries, usr, grp, md)     \
-    DEFINE_BPF_MAP_EXT(the_map, TYPE, KeyType, ValueType, num_entries, usr, grp, md,         \
-                       DEFAULT_BPF_MAP_SELINUX_CONTEXT, DEFAULT_BPF_MAP_PIN_SUBDIR, PRIVATE, \
-                       BPFLOADER_MIN_VER, BPFLOADER_MAX_VER, LOAD_ON_ENG,                    \
-                       LOAD_ON_USER, LOAD_ON_USERDEBUG)
+// for maps not meant to be accessed from userspace
+#define DEFINE_BPF_MAP_KERNEL_INTERNAL(the_map, TYPE, KeyType, ValueType, num_entries)           \
+    DEFINE_BPF_MAP_EXT(the_map, TYPE, KeyType, ValueType, num_entries, AID_ROOT, AID_ROOT,       \
+                       0000, "fs_bpf_loader", "", PRIVATE, BPFLOADER_MIN_VER, BPFLOADER_MAX_VER, \
+                       LOAD_ON_ENG, LOAD_ON_USER, LOAD_ON_USERDEBUG)
+
+#define DEFINE_BPF_MAP_UGM(the_map, TYPE, KeyType, ValueType, num_entries, usr, grp, md) \
+    DEFINE_BPF_MAP_EXT(the_map, TYPE, KeyType, ValueType, num_entries, usr, grp, md,     \
+                       DEFAULT_BPF_MAP_SELINUX_CONTEXT, DEFAULT_BPF_MAP_PIN_SUBDIR,      \
+                       PRIVATE, BPFLOADER_MIN_VER, BPFLOADER_MAX_VER,                    \
+                       LOAD_ON_ENG, LOAD_ON_USER, LOAD_ON_USERDEBUG)
 
 #define DEFINE_BPF_MAP(the_map, TYPE, KeyType, ValueType, num_entries) \
     DEFINE_BPF_MAP_UGM(the_map, TYPE, KeyType, ValueType, num_entries, \
diff --git a/bpf/loader/NetBpfLoad.cpp b/bpf/loader/NetBpfLoad.cpp
index 04393a0..4767dfa 100644
--- a/bpf/loader/NetBpfLoad.cpp
+++ b/bpf/loader/NetBpfLoad.cpp
@@ -60,7 +60,14 @@
 #include "bpf_map_def.h"
 
 using android::base::EndsWith;
+using android::base::GetIntProperty;
+using android::base::GetProperty;
+using android::base::InitLogging;
+using android::base::KernelLogger;
+using android::base::SetProperty;
+using android::base::Split;
 using android::base::StartsWith;
+using android::base::Tokenize;
 using android::base::unique_fd;
 using std::ifstream;
 using std::ios;
@@ -90,6 +97,8 @@
     net_shared,         // (T+) fs_bpf_net_shared    /sys/fs/bpf/net_shared
     netd_readonly,      // (T+) fs_bpf_netd_readonly /sys/fs/bpf/netd_readonly
     netd_shared,        // (T+) fs_bpf_netd_shared   /sys/fs/bpf/netd_shared
+    loader,             // (U+) fs_bpf_loader        /sys/fs/bpf/loader
+                        // on T due to lack of sepolicy/genfscon rules it behaves simply as 'fs_bpf'
 };
 
 static constexpr domain AllDomains[] = {
@@ -99,6 +108,7 @@
     domain::net_shared,
     domain::netd_readonly,
     domain::netd_shared,
+    domain::loader,
 };
 
 static constexpr bool specified(domain d) {
@@ -112,7 +122,7 @@
 
 // Returns the build type string (from ro.build.type).
 const std::string& getBuildType() {
-    static std::string t = android::base::GetProperty("ro.build.type", "unknown");
+    static std::string t = GetProperty("ro.build.type", "unknown");
     return t;
 }
 
@@ -144,6 +154,7 @@
         case domain::net_shared:    return "fs_bpf_net_shared";
         case domain::netd_readonly: return "fs_bpf_netd_readonly";
         case domain::netd_shared:   return "fs_bpf_netd_shared";
+        case domain::loader:        return "fs_bpf_loader";
     }
 }
 
@@ -167,6 +178,7 @@
         case domain::net_shared:    return "net_shared/";
         case domain::netd_readonly: return "netd_readonly/";
         case domain::netd_shared:   return "netd_shared/";
+        case domain::loader:        return "loader/";
     }
 };
 
@@ -184,7 +196,7 @@
 
 static string pathToObjName(const string& path) {
     // extract everything after the final slash, ie. this is the filename 'foo@1.o' or 'bar.o'
-    string filename = android::base::Split(path, "/").back();
+    string filename = Split(path, "/").back();
     // strip off everything from the final period onwards (strip '.o' suffix), ie. 'foo@1' or 'bar'
     string name = filename.substr(0, filename.find_last_of('.'));
     // strip any potential @1 suffix, this will leave us with just 'foo' or 'bar'
@@ -1016,7 +1028,7 @@
 
             if (!fd.ok()) {
                 if (log_buf.size()) {
-                    vector<string> lines = android::base::Split(log_buf.data(), "\n");
+                    vector<string> lines = Split(log_buf.data(), "\n");
 
                     ALOGW("BPF_PROG_LOAD - BEGIN log_buf contents:");
                     for (const auto& line : lines) ALOGW("%s", line.c_str());
@@ -1247,7 +1259,7 @@
 // to include a newline to match 'echo "value" > /proc/sys/...foo' behaviour,
 // which is usually how kernel devs test the actual sysctl interfaces.
 static int writeProcSysFile(const char *filename, const char *value) {
-    base::unique_fd fd(open(filename, O_WRONLY | O_CLOEXEC));
+    unique_fd fd(open(filename, O_WRONLY | O_CLOEXEC));
     if (fd < 0) {
         const int err = errno;
         ALOGE("open('%s', O_WRONLY | O_CLOEXEC) -> %s", filename, strerror(err));
@@ -1324,7 +1336,7 @@
 }
 
 static bool hasGSM() {
-    static string ph = base::GetProperty("gsm.current.phone-type", "");
+    static string ph = GetProperty("gsm.current.phone-type", "");
     static bool gsm = (ph != "");
     static bool logged = false;
     if (!logged) {
@@ -1337,7 +1349,7 @@
 static bool isTV() {
     if (hasGSM()) return false;  // TVs don't do GSM
 
-    static string key = base::GetProperty("ro.oem.key1", "");
+    static string key = GetProperty("ro.oem.key1", "");
     static bool tv = StartsWith(key, "ATV00");
     static bool logged = false;
     if (!logged) {
@@ -1348,10 +1360,10 @@
 }
 
 static bool isWear() {
-    static string wearSdkStr = base::GetProperty("ro.cw_build.wear_sdk.version", "");
-    static int wearSdkInt = base::GetIntProperty("ro.cw_build.wear_sdk.version", 0);
-    static string buildChars = base::GetProperty("ro.build.characteristics", "");
-    static vector<string> v = base::Tokenize(buildChars, ",");
+    static string wearSdkStr = GetProperty("ro.cw_build.wear_sdk.version", "");
+    static int wearSdkInt = GetIntProperty("ro.cw_build.wear_sdk.version", 0);
+    static string buildChars = GetProperty("ro.build.characteristics", "");
+    static vector<string> v = Tokenize(buildChars, ",");
     static bool watch = (std::find(v.begin(), v.end(), "watch") != v.end());
     static bool wear = (wearSdkInt > 0) || watch;
     static bool logged = false;
@@ -1368,7 +1380,7 @@
 
     // Any released device will have codename REL instead of a 'real' codename.
     // For safety: default to 'REL' so we default to unreleased=false on failure.
-    const bool unreleased = (base::GetProperty("ro.build.version.codename", "REL") != "REL");
+    const bool unreleased = (GetProperty("ro.build.version.codename", "REL") != "REL");
 
     // goog/main device_api_level is bumped *way* before aosp/main api level
     // (the latter only gets bumped during the push of goog/main to aosp/main)
@@ -1399,7 +1411,7 @@
     const bool isAtLeastV = (effective_api_level >= __ANDROID_API_V__);
     const bool isAtLeastW = (effective_api_level >  __ANDROID_API_V__);  // TODO: switch to W
 
-    const int first_api_level = base::GetIntProperty("ro.board.first_api_level", effective_api_level);
+    const int first_api_level = GetIntProperty("ro.board.first_api_level", effective_api_level);
 
     // last in U QPR2 beta1
     const bool has_platform_bpfloader_rc = exists("/system/etc/init/bpfloader.rc");
@@ -1620,7 +1632,7 @@
 
     int key = 1;
     int value = 123;
-    base::unique_fd map(
+    unique_fd map(
             createMap(BPF_MAP_TYPE_ARRAY, sizeof(key), sizeof(value), 2, 0));
     if (writeToMapEntry(map, &key, &value, BPF_ANY)) {
         ALOGE("Critical kernel bug - failure to write into index 1 of 2 element bpf map array.");
@@ -1652,11 +1664,11 @@
 }  // namespace android
 
 int main(int argc, char** argv, char * const envp[]) {
-    android::base::InitLogging(argv, &android::base::KernelLogger);
+    InitLogging(argv, &KernelLogger);
 
     if (argc == 2 && !strcmp(argv[1], "done")) {
         // we're being re-exec'ed from platform bpfloader to 'finalize' things
-        if (!android::base::SetProperty("bpf.progs_loaded", "1")) {
+        if (!SetProperty("bpf.progs_loaded", "1")) {
             ALOGE("Failed to set bpf.progs_loaded property to 1.");
             return 125;
         }
diff --git a/bpf/progs/bpf_net_helpers.h b/bpf/progs/bpf_net_helpers.h
index a86c3e6..a5664ba 100644
--- a/bpf/progs/bpf_net_helpers.h
+++ b/bpf/progs/bpf_net_helpers.h
@@ -139,6 +139,24 @@
     if (skb->data_end - skb->data < len) bpf_skb_pull_data(skb, len);
 }
 
+// anti-compiler-optimizer no-op: explicitly force full calculation of 'v'
+//
+// The use for this is to force full calculation of a complex arithmetic (likely binary
+// bitops) value, and then check the result only once (thus likely reducing the number
+// of required conditional jump instructions that badly affect bpf verifier runtime)
+//
+// The compiler cannot look into the assembly statement, so it doesn't know it does nothing.
+// Since the statement takes 'v' as both input and output in a register (+r),
+// the compiler must fully calculate the precise value of 'v' before this,
+// and must use the (possibly modified) value of 'v' afterwards (thus cannot
+// do funky optimizations to use partial results from before the asm).
+//
+// As this is not flagged 'volatile' this may still be moved out of a loop,
+// or even entirely optimized out if 'v' is never used afterwards.
+//
+// See: https://gcc.gnu.org/onlinedocs/gcc/Extended-Asm.html
+#define COMPILER_FORCE_CALCULATION(v) asm ("" : "+r" (v))
+
 struct egress_bool { bool egress; };
 #define INGRESS ((struct egress_bool){ .egress = false })
 #define EGRESS ((struct egress_bool){ .egress = true })
diff --git a/bpf/progs/dscpPolicy.c b/bpf/progs/dscpPolicy.c
index 39f2961..de9723d 100644
--- a/bpf/progs/dscpPolicy.c
+++ b/bpf/progs/dscpPolicy.c
@@ -25,8 +25,8 @@
 
 // The cache is never read nor written by userspace and is indexed by socket cookie % CACHE_MAP_SIZE
 #define CACHE_MAP_SIZE 32  // should be a power of two so we can % cheaply
-DEFINE_BPF_MAP_GRO(socket_policy_cache_map, PERCPU_ARRAY, uint32_t, RuleEntry, CACHE_MAP_SIZE,
-                   AID_SYSTEM)
+DEFINE_BPF_MAP_KERNEL_INTERNAL(socket_policy_cache_map, PERCPU_ARRAY, uint32_t, RuleEntry,
+                               CACHE_MAP_SIZE)
 
 DEFINE_BPF_MAP_GRW(ipv4_dscp_policies_map, ARRAY, uint32_t, DscpPolicy, MAX_POLICIES, AID_SYSTEM)
 DEFINE_BPF_MAP_GRW(ipv6_dscp_policies_map, ARRAY, uint32_t, DscpPolicy, MAX_POLICIES, AID_SYSTEM)
@@ -113,14 +113,30 @@
     // this array lookup cannot actually fail
     RuleEntry* existing_rule = bpf_socket_policy_cache_map_lookup_elem(&cacheid);
 
-    if (existing_rule &&
-        v6_equal(src_ip, existing_rule->src_ip) &&
-        v6_equal(dst_ip, existing_rule->dst_ip) &&
-        skb->ifindex == existing_rule->ifindex &&
-        sport == existing_rule->src_port &&
-        dport == existing_rule->dst_port &&
-        protocol == existing_rule->proto) {
-        if (existing_rule->dscp_val < 0) return;
+    if (!existing_rule) return; // impossible
+
+    uint64_t nomatch = 0;
+    nomatch |= v6_not_equal(src_ip, existing_rule->src_ip);
+    nomatch |= v6_not_equal(dst_ip, existing_rule->dst_ip);
+    nomatch |= (skb->ifindex ^ existing_rule->ifindex);
+    nomatch |= (sport ^ existing_rule->src_port);
+    nomatch |= (dport ^ existing_rule->dst_port);
+    nomatch |= (protocol ^ existing_rule->proto);
+    COMPILER_FORCE_CALCULATION(nomatch);
+
+    /*
+     * After the above funky bitwise arithmetic we have 'nomatch == 0' iff
+     *   src_ip == existing_rule->src_ip &&
+     *   dst_ip == existing_rule->dst_ip &&
+     *   skb->ifindex == existing_rule->ifindex &&
+     *   sport == existing_rule->src_port &&
+     *   dport == existing_rule->dst_port &&
+     *   protocol == existing_rule->proto
+     */
+
+    if (!nomatch) {
+        if (existing_rule->dscp_val < 0) return;  // cached no-op
+
         if (ipv4) {
             uint8_t newTos = UPDATE_TOS(existing_rule->dscp_val, tos);
             bpf_l3_csum_replace(skb, l2_header_size + IP4_OFFSET(check), htons(tos), htons(newTos),
@@ -132,7 +148,7 @@
             bpf_skb_store_bytes(skb, l2_header_size, &new_first_be32, sizeof(__be32),
                 BPF_F_RECOMPUTE_CSUM);
         }
-        return;
+        return;  // cached DSCP mutation
     }
 
     // Linear scan ipv4_dscp_policies_map since no stored params match skb.
@@ -187,7 +203,8 @@
         }
     }
 
-    RuleEntry value = {
+    // Update cache with found policy.
+    *existing_rule = (RuleEntry){
         .src_ip = src_ip,
         .dst_ip = dst_ip,
         .ifindex = skb->ifindex,
@@ -197,9 +214,6 @@
         .dscp_val = new_dscp,
     };
 
-    // Update cache with found policy.
-    bpf_socket_policy_cache_map_update_elem(&cacheid, &value, BPF_ANY);
-
     if (new_dscp < 0) return;
 
     // Need to store bytes after updating map or program will not load.
diff --git a/bpf/progs/dscpPolicy.h b/bpf/progs/dscpPolicy.h
index 6a6b711..413fb0f 100644
--- a/bpf/progs/dscpPolicy.h
+++ b/bpf/progs/dscpPolicy.h
@@ -28,9 +28,6 @@
 #define v6_not_equal(a, b) ((v6_hi_be64(a) ^ v6_hi_be64(b)) \
                           | (v6_lo_be64(a) ^ v6_lo_be64(b)))
 
-// Returns 'a == b' as boolean
-#define v6_equal(a, b) (!v6_not_equal((a), (b)))
-
 typedef struct {
     struct in6_addr src_ip;
     struct in6_addr dst_ip;
diff --git a/service/src/com/android/server/connectivity/DscpPolicyValue.java b/service/src/com/android/server/connectivity/DscpPolicyValue.java
index 7162a4a..a9100ac 100644
--- a/service/src/com/android/server/connectivity/DscpPolicyValue.java
+++ b/service/src/com/android/server/connectivity/DscpPolicyValue.java
@@ -117,8 +117,8 @@
         this.proto = proto != -1 ? proto : 0;
 
         this.dscp = dscp;
-        this.match_src_ip = (this.src46 != EMPTY_ADDRESS_FIELD);
-        this.match_dst_ip = (this.dst46 != EMPTY_ADDRESS_FIELD);
+        this.match_src_ip = (src46 != null);
+        this.match_dst_ip = (dst46 != null);
         this.match_src_port = (srcPort != -1);
         this.match_proto = (proto != -1);
     }
diff --git a/staticlibs/framework/com/android/net/module/util/LocationPermissionChecker.java b/staticlibs/framework/com/android/net/module/util/LocationPermissionChecker.java
index 28c33f3..e4d25cd 100644
--- a/staticlibs/framework/com/android/net/module/util/LocationPermissionChecker.java
+++ b/staticlibs/framework/com/android/net/module/util/LocationPermissionChecker.java
@@ -117,7 +117,11 @@
     @VisibleForTesting(visibility = PRIVATE)
     public @LocationPermissionCheckStatus int checkLocationPermissionInternal(
             String pkgName, @Nullable String featureId, int uid, @Nullable String message) {
-        checkPackage(uid, pkgName);
+        try {
+            checkPackage(uid, pkgName);
+        } catch (SecurityException e) {
+            return ERROR_LOCATION_PERMISSION_MISSING;
+        }
 
         // Apps with NETWORK_SETTINGS, NETWORK_SETUP_WIZARD, NETWORK_STACK & MAINLINE_NETWORK_STACK
         // are granted a bypass.
diff --git a/staticlibs/tests/unit/src/com/android/net/module/util/LocationPermissionCheckerTest.java b/staticlibs/tests/unit/src/com/android/net/module/util/LocationPermissionCheckerTest.java
index c8f8656..d773374 100644
--- a/staticlibs/tests/unit/src/com/android/net/module/util/LocationPermissionCheckerTest.java
+++ b/staticlibs/tests/unit/src/com/android/net/module/util/LocationPermissionCheckerTest.java
@@ -18,7 +18,6 @@
 import static android.Manifest.permission.NETWORK_SETTINGS;
 
 import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertTrue;
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.ArgumentMatchers.anyInt;
 import static org.mockito.ArgumentMatchers.anyString;
@@ -47,7 +46,6 @@
 
 import com.android.testutils.DevSdkIgnoreRule;
 
-import org.junit.Assert;
 import org.junit.Before;
 import org.junit.Rule;
 import org.junit.Test;
@@ -242,9 +240,9 @@
         mWifiScanAllowApps = AppOpsManager.MODE_ALLOWED;
         setupTestCase();
 
-        assertThrows(SecurityException.class,
-                () -> mChecker.checkLocationPermissionInternal(
-                        TEST_PKG_NAME, TEST_FEATURE_ID, mUid, null));
+        final int result = mChecker.checkLocationPermissionInternal(
+                        TEST_PKG_NAME, TEST_FEATURE_ID, mUid, null);
+        assertEquals(LocationPermissionChecker.ERROR_LOCATION_PERMISSION_MISSING, result);
     }
 
     @Test
@@ -305,14 +303,4 @@
                         TEST_PKG_NAME, TEST_FEATURE_ID, mUid, null);
         assertEquals(LocationPermissionChecker.SUCCEEDED, result);
     }
-
-
-    private static void assertThrows(Class<? extends Exception> exceptionClass, Runnable r) {
-        try {
-            r.run();
-            Assert.fail("Expected " + exceptionClass + " to be thrown.");
-        } catch (Exception exception) {
-            assertTrue(exceptionClass.isInstance(exception));
-        }
-    }
 }
diff --git a/tests/cts/net/src/android/net/cts/DscpPolicyTest.kt b/tests/cts/net/src/android/net/cts/DscpPolicyTest.kt
index f73134a..041e6cb 100644
--- a/tests/cts/net/src/android/net/cts/DscpPolicyTest.kt
+++ b/tests/cts/net/src/android/net/cts/DscpPolicyTest.kt
@@ -298,7 +298,8 @@
     fun sendPacket(
         agent: TestableNetworkAgent,
         sendV6: Boolean,
-        dstPort: Int = 0
+        dstPort: Int = 0,
+        times: Int = 1
     ) {
         val testString = "test string"
         val testPacket = ByteBuffer.wrap(testString.toByteArray(Charsets.UTF_8))
@@ -308,9 +309,11 @@
                 IPPROTO_UDP)
         checkNotNull(agent.network).bindSocket(socket)
 
-        val originalPacket = testPacket.readAsArray()
-        Os.sendto(socket, originalPacket, 0 /* bytesOffset */, originalPacket.size, 0 /* flags */,
+        val origPacket = testPacket.readAsArray()
+        repeat(times) {
+            Os.sendto(socket, origPacket, 0 /* bytesOffset */, origPacket.size, 0 /* flags */,
                 if (sendV6) TEST_TARGET_IPV6_ADDR else TEST_TARGET_IPV4_ADDR, dstPort)
+        }
         Os.close(socket)
     }
 
@@ -400,10 +403,11 @@
         agent: TestableNetworkAgent,
         sendV6: Boolean = false,
         dscpValue: Int = 0,
-        dstPort: Int = 0
+        dstPort: Int = 0,
+        times: Int = 1
     ) {
-        var packetFound = false
-        sendPacket(agent, sendV6, dstPort)
+        var packetFound = 0
+        sendPacket(agent, sendV6, dstPort, times)
         // TODO: grab source port from socket in sendPacket
 
         Log.e(TAG, "find DSCP value:" + dscpValue)
@@ -424,10 +428,23 @@
             if (parsePacketIp(buffer, sendV6) && parsePacketPort(buffer, 0, dstPort)) {
                 Log.e(TAG, "DSCP value found")
                 assertEquals(dscpValue, dscp)
-                packetFound = true
+                packetFound++
             }
         }
-        assertTrue(packetFound)
+        assertEquals(times, packetFound)
+    }
+
+    fun validatePackets(
+        agent: TestableNetworkAgent,
+        sendV6: Boolean = false,
+        dscpValue: Int = 0,
+        dstPort: Int = 0
+    ) {
+        // Send two packets from the same socket: the second one should be handled by the
+        // per-socket policy cache entry that the first one populated.
+        validatePacket(agent, sendV6, dscpValue, dstPort, times = 2)
+        // Then send one more packet from a different socket (different socket cookie).
+        validatePacket(agent, sendV6, dscpValue, dstPort, times = 1)
     }
 
     fun doRemovePolicyTest(
@@ -453,10 +470,7 @@
             assertEquals(1, it.policyId)
             assertEquals(DSCP_POLICY_STATUS_SUCCESS, it.status)
         }
-        validatePacket(agent, dscpValue = 1, dstPort = 4444)
-        // Send a second packet to validate that the stored BPF policy
-        // is correct for subsequent packets.
-        validatePacket(agent, dscpValue = 1, dstPort = 4444)
+        validatePackets(agent, dscpValue = 1, dstPort = 4444)
 
         agent.sendRemoveDscpPolicy(1)
         agent.expectCallback<OnDscpPolicyStatusUpdated>().let {
@@ -475,7 +489,7 @@
             assertEquals(DSCP_POLICY_STATUS_SUCCESS, it.status)
         }
 
-        validatePacket(agent, dscpValue = 4, dstPort = 5555)
+        validatePackets(agent, dscpValue = 4, dstPort = 5555)
 
         agent.sendRemoveDscpPolicy(1)
         agent.expectCallback<OnDscpPolicyStatusUpdated>().let {
@@ -494,10 +508,7 @@
             assertEquals(1, it.policyId)
             assertEquals(DSCP_POLICY_STATUS_SUCCESS, it.status)
         }
-        validatePacket(agent, true, dscpValue = 1, dstPort = 4444)
-        // Send a second packet to validate that the stored BPF policy
-        // is correct for subsequent packets.
-        validatePacket(agent, true, dscpValue = 1, dstPort = 4444)
+        validatePackets(agent, true, dscpValue = 1, dstPort = 4444)
 
         agent.sendRemoveDscpPolicy(1)
         agent.expectCallback<OnDscpPolicyStatusUpdated>().let {
@@ -515,7 +526,7 @@
             assertEquals(1, it.policyId)
             assertEquals(DSCP_POLICY_STATUS_SUCCESS, it.status)
         }
-        validatePacket(agent, true, dscpValue = 4, dstPort = 5555)
+        validatePackets(agent, true, dscpValue = 4, dstPort = 5555)
 
         agent.sendRemoveDscpPolicy(1)
         agent.expectCallback<OnDscpPolicyStatusUpdated>().let {
@@ -533,7 +544,7 @@
         agent.expectCallback<OnDscpPolicyStatusUpdated>().let {
             assertEquals(1, it.policyId)
             assertEquals(DSCP_POLICY_STATUS_SUCCESS, it.status)
-            validatePacket(agent, dscpValue = 1, dstPort = 1111)
+            validatePackets(agent, dscpValue = 1, dstPort = 1111)
         }
 
         val policy2 = DscpPolicy.Builder(2, 1).setDestinationPortRange(Range(2222, 2222)).build()
@@ -541,7 +552,7 @@
         agent.expectCallback<OnDscpPolicyStatusUpdated>().let {
             assertEquals(2, it.policyId)
             assertEquals(DSCP_POLICY_STATUS_SUCCESS, it.status)
-            validatePacket(agent, dscpValue = 1, dstPort = 2222)
+            validatePackets(agent, dscpValue = 1, dstPort = 2222)
         }
 
         val policy3 = DscpPolicy.Builder(3, 1).setDestinationPortRange(Range(3333, 3333)).build()
@@ -549,16 +560,16 @@
         agent.expectCallback<OnDscpPolicyStatusUpdated>().let {
             assertEquals(3, it.policyId)
             assertEquals(DSCP_POLICY_STATUS_SUCCESS, it.status)
-            validatePacket(agent, dscpValue = 1, dstPort = 3333)
+            validatePackets(agent, dscpValue = 1, dstPort = 3333)
         }
 
         /* Remove Policies and check CE is no longer set */
         doRemovePolicyTest(agent, callback, 1)
-        validatePacket(agent, dscpValue = 0, dstPort = 1111)
+        validatePackets(agent, dscpValue = 0, dstPort = 1111)
         doRemovePolicyTest(agent, callback, 2)
-        validatePacket(agent, dscpValue = 0, dstPort = 2222)
+        validatePackets(agent, dscpValue = 0, dstPort = 2222)
         doRemovePolicyTest(agent, callback, 3)
-        validatePacket(agent, dscpValue = 0, dstPort = 3333)
+        validatePackets(agent, dscpValue = 0, dstPort = 3333)
     }
 
     @Test
@@ -569,7 +580,7 @@
         agent.expectCallback<OnDscpPolicyStatusUpdated>().let {
             assertEquals(1, it.policyId)
             assertEquals(DSCP_POLICY_STATUS_SUCCESS, it.status)
-            validatePacket(agent, dscpValue = 1, dstPort = 1111)
+            validatePackets(agent, dscpValue = 1, dstPort = 1111)
         }
         doRemovePolicyTest(agent, callback, 1)
 
@@ -578,7 +589,7 @@
         agent.expectCallback<OnDscpPolicyStatusUpdated>().let {
             assertEquals(2, it.policyId)
             assertEquals(DSCP_POLICY_STATUS_SUCCESS, it.status)
-            validatePacket(agent, dscpValue = 1, dstPort = 2222)
+            validatePackets(agent, dscpValue = 1, dstPort = 2222)
         }
         doRemovePolicyTest(agent, callback, 2)
 
@@ -587,7 +598,7 @@
         agent.expectCallback<OnDscpPolicyStatusUpdated>().let {
             assertEquals(3, it.policyId)
             assertEquals(DSCP_POLICY_STATUS_SUCCESS, it.status)
-            validatePacket(agent, dscpValue = 1, dstPort = 3333)
+            validatePackets(agent, dscpValue = 1, dstPort = 3333)
         }
         doRemovePolicyTest(agent, callback, 3)
     }
@@ -601,7 +612,7 @@
         agent.expectCallback<OnDscpPolicyStatusUpdated>().let {
             assertEquals(1, it.policyId)
             assertEquals(DSCP_POLICY_STATUS_SUCCESS, it.status)
-            validatePacket(agent, dscpValue = 1, dstPort = 1111)
+            validatePackets(agent, dscpValue = 1, dstPort = 1111)
         }
 
         val policy2 = DscpPolicy.Builder(2, 1).setDestinationPortRange(Range(2222, 2222)).build()
@@ -609,7 +620,7 @@
         agent.expectCallback<OnDscpPolicyStatusUpdated>().let {
             assertEquals(2, it.policyId)
             assertEquals(DSCP_POLICY_STATUS_SUCCESS, it.status)
-            validatePacket(agent, dscpValue = 1, dstPort = 2222)
+            validatePackets(agent, dscpValue = 1, dstPort = 2222)
         }
 
         val policy3 = DscpPolicy.Builder(3, 1).setDestinationPortRange(Range(3333, 3333)).build()
@@ -617,7 +628,7 @@
         agent.expectCallback<OnDscpPolicyStatusUpdated>().let {
             assertEquals(3, it.policyId)
             assertEquals(DSCP_POLICY_STATUS_SUCCESS, it.status)
-            validatePacket(agent, dscpValue = 1, dstPort = 3333)
+            validatePackets(agent, dscpValue = 1, dstPort = 3333)
         }
 
         /* Remove Policies and check CE is no longer set */
@@ -643,7 +654,7 @@
         agent.expectCallback<OnDscpPolicyStatusUpdated>().let {
             assertEquals(1, it.policyId)
             assertEquals(DSCP_POLICY_STATUS_SUCCESS, it.status)
-            validatePacket(agent, dscpValue = 1, dstPort = 1111)
+            validatePackets(agent, dscpValue = 1, dstPort = 1111)
         }
 
         val policy2 = DscpPolicy.Builder(2, 1)
@@ -652,7 +663,7 @@
         agent.expectCallback<OnDscpPolicyStatusUpdated>().let {
             assertEquals(2, it.policyId)
             assertEquals(DSCP_POLICY_STATUS_SUCCESS, it.status)
-            validatePacket(agent, dscpValue = 1, dstPort = 2222)
+            validatePackets(agent, dscpValue = 1, dstPort = 2222)
         }
 
         val policy3 = DscpPolicy.Builder(3, 1)
@@ -661,24 +672,24 @@
         agent.expectCallback<OnDscpPolicyStatusUpdated>().let {
             assertEquals(3, it.policyId)
             assertEquals(DSCP_POLICY_STATUS_SUCCESS, it.status)
-            validatePacket(agent, dscpValue = 1, dstPort = 3333)
+            validatePackets(agent, dscpValue = 1, dstPort = 3333)
         }
 
         agent.sendRemoveAllDscpPolicies()
         agent.expectCallback<OnDscpPolicyStatusUpdated>().let {
             assertEquals(1, it.policyId)
             assertEquals(DSCP_POLICY_STATUS_DELETED, it.status)
-            validatePacket(agent, false, dstPort = 1111)
+            validatePackets(agent, false, dstPort = 1111)
         }
         agent.expectCallback<OnDscpPolicyStatusUpdated>().let {
             assertEquals(2, it.policyId)
             assertEquals(DSCP_POLICY_STATUS_DELETED, it.status)
-            validatePacket(agent, false, dstPort = 2222)
+            validatePackets(agent, false, dstPort = 2222)
         }
         agent.expectCallback<OnDscpPolicyStatusUpdated>().let {
             assertEquals(3, it.policyId)
             assertEquals(DSCP_POLICY_STATUS_DELETED, it.status)
-            validatePacket(agent, false, dstPort = 3333)
+            validatePackets(agent, false, dstPort = 3333)
         }
     }
 
@@ -690,7 +701,7 @@
         agent.expectCallback<OnDscpPolicyStatusUpdated>().let {
             assertEquals(1, it.policyId)
             assertEquals(DSCP_POLICY_STATUS_SUCCESS, it.status)
-            validatePacket(agent, dscpValue = 1, dstPort = 4444)
+            validatePackets(agent, dscpValue = 1, dstPort = 4444)
         }
 
         val policy2 = DscpPolicy.Builder(1, 1).setDestinationPortRange(Range(5555, 5555)).build()
@@ -700,8 +711,8 @@
             assertEquals(DSCP_POLICY_STATUS_SUCCESS, it.status)
 
             // Sending packet with old policy should fail
-            validatePacket(agent, dscpValue = 0, dstPort = 4444)
-            validatePacket(agent, dscpValue = 1, dstPort = 5555)
+            validatePackets(agent, dscpValue = 0, dstPort = 4444)
+            validatePackets(agent, dscpValue = 1, dstPort = 5555)
         }
 
         agent.sendRemoveDscpPolicy(1)
diff --git a/thread/tests/integration/Android.bp b/thread/tests/integration/Android.bp
index 71693af..59e8e19 100644
--- a/thread/tests/integration/Android.bp
+++ b/thread/tests/integration/Android.bp
@@ -58,6 +58,7 @@
     ],
     srcs: [
         "src/**/*.java",
+        "src/**/*.kt",
     ],
     compile_multilib: "both",
 }
diff --git a/thread/tests/integration/src/android/net/thread/BorderRoutingTest.java b/thread/tests/integration/src/android/net/thread/BorderRoutingTest.java
index 9e8dc3a..103282a 100644
--- a/thread/tests/integration/src/android/net/thread/BorderRoutingTest.java
+++ b/thread/tests/integration/src/android/net/thread/BorderRoutingTest.java
@@ -34,7 +34,6 @@
 
 import static com.android.net.module.util.NetworkStackConstants.ICMPV6_ECHO_REPLY_TYPE;
 import static com.android.net.module.util.NetworkStackConstants.ICMPV6_ECHO_REQUEST_TYPE;
-import static com.android.testutils.TestNetworkTrackerKt.initTestNetwork;
 import static com.android.testutils.TestPermissionUtil.runAsShell;
 
 import static com.google.common.truth.Truth.assertThat;
@@ -49,11 +48,9 @@
 import android.content.Context;
 import android.net.IpPrefix;
 import android.net.LinkAddress;
-import android.net.LinkProperties;
-import android.net.MacAddress;
-import android.net.RouteInfo;
 import android.net.thread.utils.FullThreadDevice;
 import android.net.thread.utils.InfraNetworkDevice;
+import android.net.thread.utils.IntegrationTestUtils;
 import android.net.thread.utils.OtDaemonController;
 import android.net.thread.utils.ThreadFeatureCheckerRule;
 import android.net.thread.utils.ThreadFeatureCheckerRule.RequiresIpv6MulticastRouting;
@@ -634,32 +631,16 @@
     }
 
     private void setUpInfraNetwork() throws Exception {
-        LinkProperties lp = new LinkProperties();
-        // NAT64 feature requires the infra network to have an IPv4 default route.
-        lp.addRoute(
-                new RouteInfo(
-                        new IpPrefix("0.0.0.0/0") /* destination */,
-                        null /* gateway */,
-                        null,
-                        RouteInfo.RTN_UNICAST,
-                        1500 /* mtu */));
-        mInfraNetworkTracker =
-                runAsShell(
-                        MANAGE_TEST_NETWORKS,
-                        () -> initTestNetwork(mContext, lp, 5000 /* timeoutMs */));
-        String infraNetworkName = mInfraNetworkTracker.getTestIface().getInterfaceName();
-        mController.setTestNetworkAsUpstreamAndWait(infraNetworkName);
+        mInfraNetworkTracker = IntegrationTestUtils.setUpInfraNetwork(mContext, mController);
     }
 
     private void tearDownInfraNetwork() {
-        runAsShell(MANAGE_TEST_NETWORKS, () -> mInfraNetworkTracker.teardown());
+        IntegrationTestUtils.tearDownInfraNetwork(mInfraNetworkTracker);
     }
 
-    private void startInfraDeviceAndWaitForOnLinkAddr() throws Exception {
+    private void startInfraDeviceAndWaitForOnLinkAddr() {
         mInfraDevice =
-                new InfraNetworkDevice(MacAddress.fromString("1:2:3:4:5:6"), mInfraNetworkReader);
-        mInfraDevice.runSlaac(Duration.ofSeconds(60));
-        assertNotNull(mInfraDevice.ipv6Addr);
+                IntegrationTestUtils.startInfraDeviceAndWaitForOnLinkAddr(mInfraNetworkReader);
     }
 
     private void assertInfraLinkMemberOfGroup(Inet6Address address) throws Exception {
diff --git a/thread/tests/integration/src/android/net/thread/utils/IntegrationTestUtils.java b/thread/tests/integration/src/android/net/thread/utils/IntegrationTestUtils.java
deleted file mode 100644
index 82e9332..0000000
--- a/thread/tests/integration/src/android/net/thread/utils/IntegrationTestUtils.java
+++ /dev/null
@@ -1,563 +0,0 @@
-/*
- * Copyright (C) 2023 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *      http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package android.net.thread.utils;
-
-import static android.net.NetworkCapabilities.NET_CAPABILITY_LOCAL_NETWORK;
-import static android.system.OsConstants.IPPROTO_ICMP;
-import static android.system.OsConstants.IPPROTO_ICMPV6;
-
-import static com.android.compatibility.common.util.SystemUtil.runShellCommandOrThrow;
-import static com.android.net.module.util.NetworkStackConstants.ICMPV6_ND_OPTION_PIO;
-import static com.android.net.module.util.NetworkStackConstants.ICMPV6_ROUTER_ADVERTISEMENT;
-
-import static com.google.common.io.BaseEncoding.base16;
-import static com.google.common.util.concurrent.MoreExecutors.directExecutor;
-
-import static org.junit.Assert.assertNotNull;
-
-import static java.util.concurrent.TimeUnit.MILLISECONDS;
-import static java.util.concurrent.TimeUnit.SECONDS;
-
-import android.net.ConnectivityManager;
-import android.net.InetAddresses;
-import android.net.LinkAddress;
-import android.net.Network;
-import android.net.NetworkCapabilities;
-import android.net.NetworkRequest;
-import android.net.TestNetworkInterface;
-import android.net.nsd.NsdManager;
-import android.net.nsd.NsdServiceInfo;
-import android.net.thread.ActiveOperationalDataset;
-import android.net.thread.ThreadNetworkController;
-import android.os.Build;
-import android.os.Handler;
-import android.os.SystemClock;
-
-import androidx.annotation.NonNull;
-import androidx.test.core.app.ApplicationProvider;
-
-import com.android.net.module.util.Struct;
-import com.android.net.module.util.structs.Icmpv4Header;
-import com.android.net.module.util.structs.Icmpv6Header;
-import com.android.net.module.util.structs.Ipv4Header;
-import com.android.net.module.util.structs.Ipv6Header;
-import com.android.net.module.util.structs.PrefixInformationOption;
-import com.android.net.module.util.structs.RaHeader;
-import com.android.testutils.HandlerUtils;
-import com.android.testutils.TapPacketReader;
-
-import com.google.common.util.concurrent.SettableFuture;
-
-import java.io.FileDescriptor;
-import java.io.IOException;
-import java.net.DatagramPacket;
-import java.net.DatagramSocket;
-import java.net.Inet4Address;
-import java.net.Inet6Address;
-import java.net.InetAddress;
-import java.net.InetSocketAddress;
-import java.net.SocketAddress;
-import java.nio.ByteBuffer;
-import java.time.Duration;
-import java.util.ArrayList;
-import java.util.List;
-import java.util.concurrent.CompletableFuture;
-import java.util.concurrent.ExecutionException;
-import java.util.concurrent.TimeUnit;
-import java.util.concurrent.TimeoutException;
-import java.util.function.Predicate;
-import java.util.function.Supplier;
-
-/** Static utility methods relating to Thread integration tests. */
-public final class IntegrationTestUtils {
-    // The timeout of join() after restarting ot-daemon. The device needs to send 6 Link Request
-    // every 5 seconds, followed by 4 Parent Request every second. So this value needs to be 40
-    // seconds to be safe
-    public static final Duration RESTART_JOIN_TIMEOUT = Duration.ofSeconds(40);
-    public static final Duration JOIN_TIMEOUT = Duration.ofSeconds(30);
-    public static final Duration LEAVE_TIMEOUT = Duration.ofSeconds(2);
-    public static final Duration CALLBACK_TIMEOUT = Duration.ofSeconds(1);
-    public static final Duration SERVICE_DISCOVERY_TIMEOUT = Duration.ofSeconds(20);
-
-    // A valid Thread Active Operational Dataset generated from OpenThread CLI "dataset init new".
-    private static final byte[] DEFAULT_DATASET_TLVS =
-            base16().decode(
-                            "0E080000000000010000000300001335060004001FFFE002"
-                                    + "08ACC214689BC40BDF0708FD64DB1225F47E0B0510F26B31"
-                                    + "53760F519A63BAFDDFFC80D2AF030F4F70656E5468726561"
-                                    + "642D643961300102D9A00410A245479C836D551B9CA557F7"
-                                    + "B9D351B40C0402A0FFF8");
-    public static final ActiveOperationalDataset DEFAULT_DATASET =
-            ActiveOperationalDataset.fromThreadTlvs(DEFAULT_DATASET_TLVS);
-
-    private IntegrationTestUtils() {}
-
-    /**
-     * Waits for the given {@link Supplier} to be true until given timeout.
-     *
-     * @param condition the condition to check
-     * @param timeout the time to wait for the condition before throwing
-     * @throws TimeoutException if the condition is still not met when the timeout expires
-     */
-    public static void waitFor(Supplier<Boolean> condition, Duration timeout)
-            throws TimeoutException {
-        final long intervalMills = 500;
-        final long timeoutMills = timeout.toMillis();
-
-        for (long i = 0; i < timeoutMills; i += intervalMills) {
-            if (condition.get()) {
-                return;
-            }
-            SystemClock.sleep(intervalMills);
-        }
-        if (condition.get()) {
-            return;
-        }
-        throw new TimeoutException("The condition failed to become true in " + timeout);
-    }
-
-    /**
-     * Creates a {@link TapPacketReader} given the {@link TestNetworkInterface} and {@link Handler}.
-     *
-     * @param testNetworkInterface the TUN interface of the test network
-     * @param handler the handler to process the packets
-     * @return the {@link TapPacketReader}
-     */
-    public static TapPacketReader newPacketReader(
-            TestNetworkInterface testNetworkInterface, Handler handler) {
-        FileDescriptor fd = testNetworkInterface.getFileDescriptor().getFileDescriptor();
-        final TapPacketReader reader =
-                new TapPacketReader(handler, fd, testNetworkInterface.getMtu());
-        handler.post(() -> reader.start());
-        HandlerUtils.waitForIdle(handler, 5000 /* timeout in milliseconds */);
-        return reader;
-    }
-
-    /**
-     * Waits for the Thread module to enter any state of the given {@code deviceRoles}.
-     *
-     * @param controller the {@link ThreadNetworkController}
-     * @param deviceRoles the desired device roles. See also {@link
-     *     ThreadNetworkController.DeviceRole}
-     * @param timeout the time to wait for the expected state before throwing
-     * @return the {@link ThreadNetworkController.DeviceRole} after waiting
-     * @throws TimeoutException if the device hasn't become any of expected roles until the timeout
-     *     expires
-     */
-    public static int waitForStateAnyOf(
-            ThreadNetworkController controller, List<Integer> deviceRoles, Duration timeout)
-            throws TimeoutException {
-        SettableFuture<Integer> future = SettableFuture.create();
-        ThreadNetworkController.StateCallback callback =
-                newRole -> {
-                    if (deviceRoles.contains(newRole)) {
-                        future.set(newRole);
-                    }
-                };
-        controller.registerStateCallback(directExecutor(), callback);
-        try {
-            return future.get(timeout.toMillis(), TimeUnit.MILLISECONDS);
-        } catch (InterruptedException | ExecutionException e) {
-            throw new TimeoutException(
-                    String.format(
-                            "The device didn't become an expected role in %s: %s",
-                            timeout, e.getMessage()));
-        } finally {
-            controller.unregisterStateCallback(callback);
-        }
-    }
-
-    /**
-     * Polls for a packet from a given {@link TapPacketReader} that satisfies the {@code filter}.
-     *
-     * @param packetReader a TUN packet reader
-     * @param filter the filter to be applied on the packet
-     * @return the first IPv6 packet that satisfies the {@code filter}. If it has waited for more
-     *     than 3000ms to read the next packet, the method will return null
-     */
-    public static byte[] pollForPacket(TapPacketReader packetReader, Predicate<byte[]> filter) {
-        byte[] packet;
-        while ((packet = packetReader.poll(3000 /* timeoutMs */, filter)) != null) {
-            return packet;
-        }
-        return null;
-    }
-
-    /** Returns {@code true} if {@code packet} is an ICMPv4 packet of given {@code type}. */
-    public static boolean isExpectedIcmpv4Packet(byte[] packet, int type) {
-        ByteBuffer buf = makeByteBuffer(packet);
-        Ipv4Header header = extractIpv4Header(buf);
-        if (header == null) {
-            return false;
-        }
-        if (header.protocol != (byte) IPPROTO_ICMP) {
-            return false;
-        }
-        try {
-            return Struct.parse(Icmpv4Header.class, buf).type == (short) type;
-        } catch (IllegalArgumentException ignored) {
-            // It's fine that the passed in packet is malformed because it's could be sent
-            // by anybody.
-        }
-        return false;
-    }
-
-    /** Returns {@code true} if {@code packet} is an ICMPv6 packet of given {@code type}. */
-    public static boolean isExpectedIcmpv6Packet(byte[] packet, int type) {
-        ByteBuffer buf = makeByteBuffer(packet);
-        Ipv6Header header = extractIpv6Header(buf);
-        if (header == null) {
-            return false;
-        }
-        if (header.nextHeader != (byte) IPPROTO_ICMPV6) {
-            return false;
-        }
-        try {
-            return Struct.parse(Icmpv6Header.class, buf).type == (short) type;
-        } catch (IllegalArgumentException ignored) {
-            // It's fine that the passed in packet is malformed because it's could be sent
-            // by anybody.
-        }
-        return false;
-    }
-
-    public static boolean isFrom(byte[] packet, InetAddress src) {
-        if (src instanceof Inet4Address) {
-            return isFromIpv4Source(packet, (Inet4Address) src);
-        } else if (src instanceof Inet6Address) {
-            return isFromIpv6Source(packet, (Inet6Address) src);
-        }
-        return false;
-    }
-
-    public static boolean isTo(byte[] packet, InetAddress dest) {
-        if (dest instanceof Inet4Address) {
-            return isToIpv4Destination(packet, (Inet4Address) dest);
-        } else if (dest instanceof Inet6Address) {
-            return isToIpv6Destination(packet, (Inet6Address) dest);
-        }
-        return false;
-    }
-
-    private static boolean isFromIpv4Source(byte[] packet, Inet4Address src) {
-        Ipv4Header header = extractIpv4Header(makeByteBuffer(packet));
-        return header != null && header.srcIp.equals(src);
-    }
-
-    private static boolean isFromIpv6Source(byte[] packet, Inet6Address src) {
-        Ipv6Header header = extractIpv6Header(makeByteBuffer(packet));
-        return header != null && header.srcIp.equals(src);
-    }
-
-    private static boolean isToIpv4Destination(byte[] packet, Inet4Address dest) {
-        Ipv4Header header = extractIpv4Header(makeByteBuffer(packet));
-        return header != null && header.dstIp.equals(dest);
-    }
-
-    private static boolean isToIpv6Destination(byte[] packet, Inet6Address dest) {
-        Ipv6Header header = extractIpv6Header(makeByteBuffer(packet));
-        return header != null && header.dstIp.equals(dest);
-    }
-
-    private static ByteBuffer makeByteBuffer(byte[] packet) {
-        return packet == null ? null : ByteBuffer.wrap(packet);
-    }
-
-    private static Ipv4Header extractIpv4Header(ByteBuffer buf) {
-        try {
-            return Struct.parse(Ipv4Header.class, buf);
-        } catch (IllegalArgumentException ignored) {
-            // It's fine that the passed in packet is malformed because it's could be sent
-            // by anybody.
-        }
-        return null;
-    }
-
-    private static Ipv6Header extractIpv6Header(ByteBuffer buf) {
-        try {
-            return Struct.parse(Ipv6Header.class, buf);
-        } catch (IllegalArgumentException ignored) {
-            // It's fine that the passed in packet is malformed because it's could be sent
-            // by anybody.
-        }
-        return null;
-    }
-
-    /** Returns the Prefix Information Options (PIO) extracted from an ICMPv6 RA message. */
-    public static List<PrefixInformationOption> getRaPios(byte[] raMsg) {
-        final ArrayList<PrefixInformationOption> pioList = new ArrayList<>();
-
-        if (raMsg == null) {
-            return pioList;
-        }
-
-        final ByteBuffer buf = ByteBuffer.wrap(raMsg);
-        final Ipv6Header ipv6Header = Struct.parse(Ipv6Header.class, buf);
-        if (ipv6Header.nextHeader != (byte) IPPROTO_ICMPV6) {
-            return pioList;
-        }
-
-        final Icmpv6Header icmpv6Header = Struct.parse(Icmpv6Header.class, buf);
-        if (icmpv6Header.type != (short) ICMPV6_ROUTER_ADVERTISEMENT) {
-            return pioList;
-        }
-
-        Struct.parse(RaHeader.class, buf);
-        while (buf.position() < raMsg.length) {
-            final int currentPos = buf.position();
-            final int type = Byte.toUnsignedInt(buf.get());
-            final int length = Byte.toUnsignedInt(buf.get());
-            if (type == ICMPV6_ND_OPTION_PIO) {
-                final ByteBuffer pioBuf =
-                        ByteBuffer.wrap(
-                                buf.array(),
-                                currentPos,
-                                Struct.getSize(PrefixInformationOption.class));
-                final PrefixInformationOption pio =
-                        Struct.parse(PrefixInformationOption.class, pioBuf);
-                pioList.add(pio);
-
-                // Move ByteBuffer position to the next option.
-                buf.position(currentPos + Struct.getSize(PrefixInformationOption.class));
-            } else {
-                // The length is in units of 8 octets.
-                buf.position(currentPos + (length * 8));
-            }
-        }
-        return pioList;
-    }
-
-    /**
-     * Sends a UDP message to a destination.
-     *
-     * @param dstAddress the IP address of the destination
-     * @param dstPort the port of the destination
-     * @param message the message in UDP payload
-     * @throws IOException if failed to send the message
-     */
-    public static void sendUdpMessage(InetAddress dstAddress, int dstPort, String message)
-            throws IOException {
-        SocketAddress dstSockAddr = new InetSocketAddress(dstAddress, dstPort);
-
-        try (DatagramSocket socket = new DatagramSocket()) {
-            socket.connect(dstSockAddr);
-
-            byte[] msgBytes = message.getBytes();
-            DatagramPacket packet = new DatagramPacket(msgBytes, msgBytes.length);
-
-            socket.send(packet);
-        }
-    }
-
-    public static boolean isInMulticastGroup(String interfaceName, Inet6Address address) {
-        final String cmd = "ip -6 maddr show dev " + interfaceName;
-        final String output = runShellCommandOrThrow(cmd);
-        final String addressStr = address.getHostAddress();
-        for (final String line : output.split("\\n")) {
-            if (line.contains(addressStr)) {
-                return true;
-            }
-        }
-        return false;
-    }
-
-    public static List<LinkAddress> getIpv6LinkAddresses(String interfaceName) {
-        List<LinkAddress> addresses = new ArrayList<>();
-        final String cmd = " ip -6 addr show dev " + interfaceName;
-        final String output = runShellCommandOrThrow(cmd);
-
-        for (final String line : output.split("\\n")) {
-            if (line.contains("inet6")) {
-                addresses.add(parseAddressLine(line));
-            }
-        }
-
-        return addresses;
-    }
-
-    /** Return the first discovered service of {@code serviceType}. */
-    public static NsdServiceInfo discoverService(NsdManager nsdManager, String serviceType)
-            throws Exception {
-        CompletableFuture<NsdServiceInfo> serviceInfoFuture = new CompletableFuture<>();
-        NsdManager.DiscoveryListener listener =
-                new DefaultDiscoveryListener() {
-                    @Override
-                    public void onServiceFound(NsdServiceInfo serviceInfo) {
-                        serviceInfoFuture.complete(serviceInfo);
-                    }
-                };
-        nsdManager.discoverServices(serviceType, NsdManager.PROTOCOL_DNS_SD, listener);
-        try {
-            serviceInfoFuture.get(SERVICE_DISCOVERY_TIMEOUT.toMillis(), MILLISECONDS);
-        } finally {
-            nsdManager.stopServiceDiscovery(listener);
-        }
-
-        return serviceInfoFuture.get();
-    }
-
-    /**
-     * Returns the {@link NsdServiceInfo} when a service instance of {@code serviceType} gets lost.
-     */
-    public static NsdManager.DiscoveryListener discoverForServiceLost(
-            NsdManager nsdManager,
-            String serviceType,
-            CompletableFuture<NsdServiceInfo> serviceInfoFuture) {
-        NsdManager.DiscoveryListener listener =
-                new DefaultDiscoveryListener() {
-                    @Override
-                    public void onServiceLost(NsdServiceInfo serviceInfo) {
-                        serviceInfoFuture.complete(serviceInfo);
-                    }
-                };
-        nsdManager.discoverServices(serviceType, NsdManager.PROTOCOL_DNS_SD, listener);
-        return listener;
-    }
-
-    /** Resolves the service. */
-    public static NsdServiceInfo resolveService(NsdManager nsdManager, NsdServiceInfo serviceInfo)
-            throws Exception {
-        return resolveServiceUntil(nsdManager, serviceInfo, s -> true);
-    }
-
-    /** Returns the first resolved service that satisfies the {@code predicate}. */
-    public static NsdServiceInfo resolveServiceUntil(
-            NsdManager nsdManager, NsdServiceInfo serviceInfo, Predicate<NsdServiceInfo> predicate)
-            throws Exception {
-        CompletableFuture<NsdServiceInfo> resolvedServiceInfoFuture = new CompletableFuture<>();
-        NsdManager.ServiceInfoCallback callback =
-                new DefaultServiceInfoCallback() {
-                    @Override
-                    public void onServiceUpdated(@NonNull NsdServiceInfo serviceInfo) {
-                        if (predicate.test(serviceInfo)) {
-                            resolvedServiceInfoFuture.complete(serviceInfo);
-                        }
-                    }
-                };
-        nsdManager.registerServiceInfoCallback(serviceInfo, directExecutor(), callback);
-        try {
-            return resolvedServiceInfoFuture.get(
-                    SERVICE_DISCOVERY_TIMEOUT.toMillis(), MILLISECONDS);
-        } finally {
-            nsdManager.unregisterServiceInfoCallback(callback);
-        }
-    }
-
-    public static String getPrefixesFromNetData(String netData) {
-        int startIdx = netData.indexOf("Prefixes:");
-        int endIdx = netData.indexOf("Routes:");
-        return netData.substring(startIdx, endIdx);
-    }
-
-    public static Network getThreadNetwork(Duration timeout) throws Exception {
-        CompletableFuture<Network> networkFuture = new CompletableFuture<>();
-        ConnectivityManager cm =
-                ApplicationProvider.getApplicationContext()
-                        .getSystemService(ConnectivityManager.class);
-        NetworkRequest.Builder networkRequestBuilder =
-                new NetworkRequest.Builder().addTransportType(NetworkCapabilities.TRANSPORT_THREAD);
-        // Before V, we need to explicitly set `NET_CAPABILITY_LOCAL_NETWORK` capability to request
-        // a Thread network.
-        if (Build.VERSION.SDK_INT <= Build.VERSION_CODES.UPSIDE_DOWN_CAKE) {
-            networkRequestBuilder.addCapability(NET_CAPABILITY_LOCAL_NETWORK);
-        }
-        NetworkRequest networkRequest = networkRequestBuilder.build();
-        ConnectivityManager.NetworkCallback networkCallback =
-                new ConnectivityManager.NetworkCallback() {
-                    @Override
-                    public void onAvailable(Network network) {
-                        networkFuture.complete(network);
-                    }
-                };
-        cm.registerNetworkCallback(networkRequest, networkCallback);
-        return networkFuture.get(timeout.toSeconds(), SECONDS);
-    }
-
-    /**
-     * Let the FTD join the specified Thread network and wait for border routing to be available.
-     *
-     * @return the OMR address
-     */
-    public static Inet6Address joinNetworkAndWaitForOmr(
-            FullThreadDevice ftd, ActiveOperationalDataset dataset) throws Exception {
-        ftd.factoryReset();
-        ftd.joinNetwork(dataset);
-        ftd.waitForStateAnyOf(List.of("router", "child"), JOIN_TIMEOUT);
-        waitFor(() -> ftd.getOmrAddress() != null, Duration.ofSeconds(60));
-        Inet6Address ftdOmr = ftd.getOmrAddress();
-        assertNotNull(ftdOmr);
-        return ftdOmr;
-    }
-
-    private static class DefaultDiscoveryListener implements NsdManager.DiscoveryListener {
-        @Override
-        public void onStartDiscoveryFailed(String serviceType, int errorCode) {}
-
-        @Override
-        public void onStopDiscoveryFailed(String serviceType, int errorCode) {}
-
-        @Override
-        public void onDiscoveryStarted(String serviceType) {}
-
-        @Override
-        public void onDiscoveryStopped(String serviceType) {}
-
-        @Override
-        public void onServiceFound(NsdServiceInfo serviceInfo) {}
-
-        @Override
-        public void onServiceLost(NsdServiceInfo serviceInfo) {}
-    }
-
-    private static class DefaultServiceInfoCallback implements NsdManager.ServiceInfoCallback {
-        @Override
-        public void onServiceInfoCallbackRegistrationFailed(int errorCode) {}
-
-        @Override
-        public void onServiceUpdated(@NonNull NsdServiceInfo serviceInfo) {}
-
-        @Override
-        public void onServiceLost() {}
-
-        @Override
-        public void onServiceInfoCallbackUnregistered() {}
-    }
-
-    /**
-     * Parses a line of output from "ip -6 addr show" into a {@link LinkAddress}.
-     *
-     * <p>Example line: "inet6 2001:db8:1:1::1/64 scope global deprecated"
-     */
-    private static LinkAddress parseAddressLine(String line) {
-        String[] parts = line.trim().split("\\s+");
-        String addressString = parts[1];
-        String[] pieces = addressString.split("/", 2);
-        int prefixLength = Integer.parseInt(pieces[1]);
-        final InetAddress address = InetAddresses.parseNumericAddress(pieces[0]);
-        long deprecationTimeMillis =
-                line.contains("deprecated")
-                        ? SystemClock.elapsedRealtime()
-                        : LinkAddress.LIFETIME_PERMANENT;
-
-        return new LinkAddress(
-                address,
-                prefixLength,
-                0 /* flags */,
-                0 /* scope */,
-                deprecationTimeMillis,
-                LinkAddress.LIFETIME_PERMANENT /* expirationTime */);
-    }
-}
diff --git a/thread/tests/integration/src/android/net/thread/utils/IntegrationTestUtils.kt b/thread/tests/integration/src/android/net/thread/utils/IntegrationTestUtils.kt
new file mode 100644
index 0000000..fa9855e
--- /dev/null
+++ b/thread/tests/integration/src/android/net/thread/utils/IntegrationTestUtils.kt
@@ -0,0 +1,598 @@
+/*
+ * Copyright (C) 2024 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package android.net.thread.utils
+
+import android.Manifest.permission.MANAGE_TEST_NETWORKS
+import android.content.Context
+import android.net.ConnectivityManager
+import android.net.InetAddresses.parseNumericAddress
+import android.net.IpPrefix
+import android.net.LinkAddress
+import android.net.LinkProperties
+import android.net.MacAddress
+import android.net.Network
+import android.net.NetworkCapabilities
+import android.net.NetworkRequest
+import android.net.RouteInfo
+import android.net.TestNetworkInterface
+import android.net.nsd.NsdManager
+import android.net.nsd.NsdServiceInfo
+import android.net.thread.ActiveOperationalDataset
+import android.net.thread.ThreadNetworkController
+import android.os.Build
+import android.os.Handler
+import android.os.SystemClock
+import android.system.OsConstants
+import androidx.test.core.app.ApplicationProvider
+import com.android.compatibility.common.util.SystemUtil.runShellCommandOrThrow
+import com.android.net.module.util.NetworkStackConstants
+import com.android.net.module.util.Struct
+import com.android.net.module.util.structs.Icmpv4Header
+import com.android.net.module.util.structs.Icmpv6Header
+import com.android.net.module.util.structs.Ipv4Header
+import com.android.net.module.util.structs.Ipv6Header
+import com.android.net.module.util.structs.PrefixInformationOption
+import com.android.net.module.util.structs.RaHeader
+import com.android.testutils.TapPacketReader
+import com.android.testutils.TestNetworkTracker
+import com.android.testutils.initTestNetwork
+import com.android.testutils.runAsShell
+import com.android.testutils.waitForIdle
+import com.google.common.io.BaseEncoding
+import com.google.common.util.concurrent.MoreExecutors
+import com.google.common.util.concurrent.MoreExecutors.directExecutor
+import com.google.common.util.concurrent.SettableFuture
+import java.io.IOException
+import java.lang.Byte.toUnsignedInt
+import java.net.DatagramPacket
+import java.net.DatagramSocket
+import java.net.Inet4Address
+import java.net.Inet6Address
+import java.net.InetAddress
+import java.net.InetSocketAddress
+import java.net.SocketAddress
+import java.nio.ByteBuffer
+import java.time.Duration
+import java.util.concurrent.CompletableFuture
+import java.util.concurrent.ExecutionException
+import java.util.concurrent.TimeUnit
+import java.util.concurrent.TimeoutException
+import java.util.function.Predicate
+import java.util.function.Supplier
+import org.junit.Assert
+
+/** Utilities for Thread integration tests. */
+object IntegrationTestUtils {
+    // The timeout of join() after restarting ot-daemon. The device needs to send 6 Link Requests
+    // every 5 seconds, followed by 4 Parent Requests every second. So this value needs to be 40
+    // seconds to be safe.
+    @JvmField
+    val RESTART_JOIN_TIMEOUT: Duration = Duration.ofSeconds(40)
+
+    @JvmField
+    val JOIN_TIMEOUT: Duration = Duration.ofSeconds(30)
+
+    @JvmField
+    val LEAVE_TIMEOUT: Duration = Duration.ofSeconds(2)
+
+    @JvmField
+    val CALLBACK_TIMEOUT: Duration = Duration.ofSeconds(1)
+
+    @JvmField
+    val SERVICE_DISCOVERY_TIMEOUT: Duration = Duration.ofSeconds(20)
+
+    // Matches runs of whitespace; used to tokenize the output of `ip` shell commands.
+    private val WHITESPACE_REGEX = Regex("\\s+")
+
+    // A valid Thread Active Operational Dataset generated from OpenThread CLI "dataset init new".
+    private val DEFAULT_DATASET_TLVS: ByteArray = BaseEncoding.base16().decode(
+        ("0E080000000000010000000300001335060004001FFFE002"
+                + "08ACC214689BC40BDF0708FD64DB1225F47E0B0510F26B31"
+                + "53760F519A63BAFDDFFC80D2AF030F4F70656E5468726561"
+                + "642D643961300102D9A00410A245479C836D551B9CA557F7"
+                + "B9D351B40C0402A0FFF8")
+    )
+
+    @JvmField
+    val DEFAULT_DATASET: ActiveOperationalDataset =
+        ActiveOperationalDataset.fromThreadTlvs(DEFAULT_DATASET_TLVS)
+
+    /**
+     * Waits for the given [Supplier] to be true until given timeout.
+     *
+     * The condition is polled every 500ms and is checked one last time after the timeout expires.
+     *
+     * @param condition the condition to check
+     * @param timeout the time to wait for the condition before throwing
+     * @throws TimeoutException if the condition is still not met when the timeout expires
+     */
+    @JvmStatic
+    @Throws(TimeoutException::class)
+    fun waitFor(condition: Supplier<Boolean>, timeout: Duration) {
+        val intervalMillis: Long = 500
+        val timeoutMillis = timeout.toMillis()
+
+        // Elapsed time is approximated by the number of sleeps, not measured on a clock.
+        var elapsedMillis: Long = 0
+        while (elapsedMillis < timeoutMillis) {
+            if (condition.get()) return
+            SystemClock.sleep(intervalMillis)
+            elapsedMillis += intervalMillis
+        }
+        // Final check, so a condition that became true during the last sleep still passes.
+        if (condition.get()) return
+        throw TimeoutException("The condition failed to become true in $timeout")
+    }
+
+    /**
+     * Creates a [TapPacketReader] given the [TestNetworkInterface] and [Handler].
+     *
+     * @param testNetworkInterface the TUN interface of the test network
+     * @param handler the handler to process the packets
+     * @return the [TapPacketReader]
+     */
+    @JvmStatic
+    fun newPacketReader(
+        testNetworkInterface: TestNetworkInterface, handler: Handler
+    ): TapPacketReader {
+        val fd = testNetworkInterface.fileDescriptor.fileDescriptor
+        val reader = TapPacketReader(handler, fd, testNetworkInterface.mtu)
+        handler.post { reader.start() }
+        handler.waitForIdle(timeoutMs = 5000)
+        return reader
+    }
+
+    /**
+     * Waits for the Thread module to enter any state of the given `deviceRoles`.
+     *
+     * @param controller the [ThreadNetworkController]
+     * @param deviceRoles the desired device roles. See also [ThreadNetworkController.DeviceRole]
+     * @param timeout the time to wait for the expected state before throwing
+     * @return the [ThreadNetworkController.DeviceRole] after waiting
+     * @throws TimeoutException if the device hasn't become any of expected roles until the timeout
+     * expires
+     */
+    @JvmStatic
+    @Throws(TimeoutException::class)
+    fun waitForStateAnyOf(
+        controller: ThreadNetworkController, deviceRoles: List<Int>, timeout: Duration
+    ): Int {
+        val future = SettableFuture.create<Int>()
+        val callback = ThreadNetworkController.StateCallback { newRole: Int ->
+            if (deviceRoles.contains(newRole)) {
+                future.set(newRole)
+            }
+        }
+        controller.registerStateCallback(MoreExecutors.directExecutor(), callback)
+        try {
+            return future[timeout.toMillis(), TimeUnit.MILLISECONDS]
+        } catch (e: InterruptedException) {
+            Thread.currentThread().interrupt() // Preserve the interrupt for the caller.
+            throw TimeoutException(
+                "The device didn't become an expected role in $timeout: ${e.message}"
+            )
+        } catch (e: ExecutionException) {
+            throw TimeoutException(
+                "The device didn't become an expected role in $timeout: ${e.message}"
+            )
+        } finally {
+            controller.unregisterStateCallback(callback)
+        }
+    }
+
+    /**
+     * Polls for a packet from a given [TapPacketReader] that satisfies the `filter`.
+     *
+     * @param packetReader a TUN packet reader
+     * @param filter the filter to be applied on the packet
+     * @return the first packet that satisfies the `filter`, or null if no matching packet is read
+     * within 3000ms
+     */
+    @JvmStatic
+    fun pollForPacket(packetReader: TapPacketReader, filter: Predicate<ByteArray>): ByteArray? =
+        packetReader.poll(3000 /* timeoutMs */, filter)
+
+    /** Returns `true` if `packet` is an ICMPv4 packet of given `type`. */
+    @JvmStatic
+    fun isExpectedIcmpv4Packet(packet: ByteArray, type: Int): Boolean {
+        val buf = makeByteBuffer(packet)
+        val header = extractIpv4Header(buf) ?: return false
+        if (header.protocol != OsConstants.IPPROTO_ICMP.toByte()) return false
+        return try {
+            Struct.parse(Icmpv4Header::class.java, buf).type == type.toShort()
+        } catch (ignored: IllegalArgumentException) {
+            // A malformed packet is tolerated: it could have been sent by anybody.
+            false
+        }
+    }
+
+    /** Returns `true` if `packet` is an ICMPv6 packet of given `type`. */
+    @JvmStatic
+    fun isExpectedIcmpv6Packet(packet: ByteArray, type: Int): Boolean {
+        val buf = makeByteBuffer(packet)
+        val header = extractIpv6Header(buf) ?: return false
+        if (header.nextHeader != OsConstants.IPPROTO_ICMPV6.toByte()) return false
+        return try {
+            Struct.parse(Icmpv6Header::class.java, buf).type == type.toShort()
+        } catch (ignored: IllegalArgumentException) {
+            // A malformed packet is tolerated: it could have been sent by anybody.
+            false
+        }
+    }
+
+    /** Returns `true` if the source address of `packet` (per `src`'s family) is `src`. */
+    @JvmStatic
+    fun isFrom(packet: ByteArray, src: InetAddress): Boolean = when (src) {
+        is Inet4Address -> isFromIpv4Source(packet, src)
+        is Inet6Address -> isFromIpv6Source(packet, src)
+        else -> false
+    }
+
+    /** Returns `true` if the destination address of `packet` (per `dest`'s family) is `dest`. */
+    @JvmStatic
+    fun isTo(packet: ByteArray, dest: InetAddress): Boolean = when (dest) {
+        is Inet4Address -> isToIpv4Destination(packet, dest)
+        is Inet6Address -> isToIpv6Destination(packet, dest)
+        else -> false
+    }
+
+    private fun isFromIpv4Source(packet: ByteArray, src: Inet4Address): Boolean =
+        extractIpv4Header(makeByteBuffer(packet))?.srcIp == src
+
+    private fun isFromIpv6Source(packet: ByteArray, src: Inet6Address): Boolean =
+        extractIpv6Header(makeByteBuffer(packet))?.srcIp == src
+
+    private fun isToIpv4Destination(packet: ByteArray, dest: Inet4Address): Boolean =
+        extractIpv4Header(makeByteBuffer(packet))?.dstIp == dest
+
+    private fun isToIpv6Destination(packet: ByteArray, dest: Inet6Address): Boolean =
+        extractIpv6Header(makeByteBuffer(packet))?.dstIp == dest
+
+    private fun makeByteBuffer(packet: ByteArray): ByteBuffer = ByteBuffer.wrap(packet)
+
+    /** Parses an [Ipv4Header] from `buf`; returns null if parsing throws IllegalArgumentException. */
+    private fun extractIpv4Header(buf: ByteBuffer): Ipv4Header? = try {
+        Struct.parse(Ipv4Header::class.java, buf)
+    } catch (ignored: IllegalArgumentException) {
+        null // A malformed packet is tolerated: it could have been sent by anybody.
+    }
+
+    /** Parses an [Ipv6Header] from `buf`; returns null if parsing throws IllegalArgumentException. */
+    private fun extractIpv6Header(buf: ByteBuffer): Ipv6Header? = try {
+        Struct.parse(Ipv6Header::class.java, buf)
+    } catch (ignored: IllegalArgumentException) {
+        null // A malformed packet is tolerated: it could have been sent by anybody.
+    }
+
+    /**
+     * Returns the Prefix Information Options (PIO) extracted from an ICMPv6 RA message.
+     *
+     * Returns an empty list if `raMsg` is null or is not an ICMPv6 Router Advertisement.
+     */
+    @JvmStatic
+    fun getRaPios(raMsg: ByteArray?): List<PrefixInformationOption> {
+        val pioList = ArrayList<PrefixInformationOption>()
+
+        raMsg ?: return pioList
+
+        val buf = ByteBuffer.wrap(raMsg)
+        val ipv6Header = Struct.parse(Ipv6Header::class.java, buf)
+        if (ipv6Header.nextHeader != OsConstants.IPPROTO_ICMPV6.toByte()) {
+            return pioList
+        }
+
+        val icmpv6Header = Struct.parse(Icmpv6Header::class.java, buf)
+        if (icmpv6Header.type != NetworkStackConstants.ICMPV6_ROUTER_ADVERTISEMENT.toShort()) {
+            return pioList
+        }
+
+        Struct.parse(RaHeader::class.java, buf) // Skip the RA header; the options follow it.
+        val pioSize = Struct.getSize(PrefixInformationOption::class.java)
+        while (buf.position() < raMsg.size) {
+            val currentPos = buf.position()
+            val type = toUnsignedInt(buf.get())
+            val length = toUnsignedInt(buf.get())
+            // A zero-length option is malformed (RFC 4861 section 4.6) and would otherwise make
+            // this loop spin forever without advancing, so stop parsing.
+            if (length == 0) break
+            if (type == NetworkStackConstants.ICMPV6_ND_OPTION_PIO) {
+                val pioBuf = ByteBuffer.wrap(buf.array(), currentPos, pioSize)
+                pioList.add(Struct.parse(PrefixInformationOption::class.java, pioBuf))
+
+                // Move ByteBuffer position to the next option.
+                buf.position(currentPos + pioSize)
+            } else {
+                // The length is in units of 8 octets.
+                buf.position(currentPos + (length * 8))
+            }
+        }
+        return pioList
+    }
+
+    /**
+     * Sends a UDP message to a destination.
+     *
+     * @param dstAddress the IP address of the destination
+     * @param dstPort the port of the destination
+     * @param message the message in UDP payload
+     * @throws IOException if failed to send the message
+     */
+    @JvmStatic
+    @Throws(IOException::class)
+    fun sendUdpMessage(dstAddress: InetAddress, dstPort: Int, message: String) {
+        val dstSockAddr: SocketAddress = InetSocketAddress(dstAddress, dstPort)
+
+        DatagramSocket().use { socket ->
+            socket.connect(dstSockAddr)
+            val msgBytes = message.toByteArray()
+            val packet = DatagramPacket(msgBytes, msgBytes.size)
+            socket.send(packet)
+        }
+    }
+
+    /**
+     * Returns `true` if `interfaceName` has joined the IPv6 multicast group `address`, according
+     * to the output of `ip -6 maddr show`.
+     */
+    @JvmStatic
+    fun isInMulticastGroup(interfaceName: String, address: Inet6Address): Boolean {
+        val cmd = "ip -6 maddr show dev $interfaceName"
+        val output: String = runShellCommandOrThrow(cmd)
+        val addressStr = address.hostAddress
+        // Match whole tokens: a substring check would wrongly report membership of "ff05::1"
+        // when only e.g. "ff05::1234" has been joined.
+        return output.lineSequence().any { line ->
+            line.trim().split(WHITESPACE_REGEX).contains(addressStr)
+        }
+    }
+
+    /** Returns the IPv6 addresses of `interfaceName`, parsed from `ip -6 addr show`. */
+    @JvmStatic
+    fun getIpv6LinkAddresses(interfaceName: String): List<LinkAddress> {
+        val cmd = "ip -6 addr show dev $interfaceName"
+        val output: String = runShellCommandOrThrow(cmd)
+
+        return output.lineSequence()
+            .filter { it.contains("inet6") }
+            .map { parseAddressLine(it) }
+            .toList()
+    }
+
+    /**
+     * Returns the first discovered service of `serviceType`.
+     *
+     * @throws TimeoutException if no service is found within [SERVICE_DISCOVERY_TIMEOUT]
+     */
+    @JvmStatic
+    @Throws(Exception::class)
+    fun discoverService(nsdManager: NsdManager, serviceType: String): NsdServiceInfo {
+        val serviceInfoFuture = CompletableFuture<NsdServiceInfo>()
+        val listener: NsdManager.DiscoveryListener = object : DefaultDiscoveryListener() {
+            override fun onServiceFound(serviceInfo: NsdServiceInfo) {
+                serviceInfoFuture.complete(serviceInfo)
+            }
+        }
+        nsdManager.discoverServices(serviceType, NsdManager.PROTOCOL_DNS_SD, listener)
+        try {
+            return serviceInfoFuture[SERVICE_DISCOVERY_TIMEOUT.toMillis(), TimeUnit.MILLISECONDS]
+        } finally {
+            nsdManager.stopServiceDiscovery(listener)
+        }
+    }
+
+    /**
+     * Starts discovering services of `serviceType` and completes `serviceInfoFuture` with the
+     * first service instance that gets lost.
+     *
+     * @return the discovery listener; the caller is responsible for stopping the discovery
+     */
+    @JvmStatic
+    fun discoverForServiceLost(
+        nsdManager: NsdManager,
+        serviceType: String?,
+        serviceInfoFuture: CompletableFuture<NsdServiceInfo?>
+    ): NsdManager.DiscoveryListener {
+        val listener: NsdManager.DiscoveryListener = object : DefaultDiscoveryListener() {
+            override fun onServiceLost(serviceInfo: NsdServiceInfo) {
+                serviceInfoFuture.complete(serviceInfo)
+            }
+        }
+        nsdManager.discoverServices(serviceType, NsdManager.PROTOCOL_DNS_SD, listener)
+        return listener
+    }
+
+    /** Resolves `serviceInfo`, returning the first update reported for it. */
+    @JvmStatic
+    @Throws(Exception::class)
+    fun resolveService(nsdManager: NsdManager, serviceInfo: NsdServiceInfo): NsdServiceInfo {
+        return resolveServiceUntil(nsdManager, serviceInfo) { true }
+    }
+
+    /** Returns the first resolved service that satisfies the `predicate`. */
+    @JvmStatic
+    @Throws(Exception::class)
+    fun resolveServiceUntil(
+        nsdManager: NsdManager, serviceInfo: NsdServiceInfo, predicate: Predicate<NsdServiceInfo>
+    ): NsdServiceInfo {
+        val resolvedServiceInfoFuture = CompletableFuture<NsdServiceInfo>()
+        val callback: NsdManager.ServiceInfoCallback = object : DefaultServiceInfoCallback() {
+            override fun onServiceUpdated(serviceInfo: NsdServiceInfo) {
+                if (predicate.test(serviceInfo)) {
+                    resolvedServiceInfoFuture.complete(serviceInfo)
+                }
+            }
+        }
+        nsdManager.registerServiceInfoCallback(serviceInfo, directExecutor(), callback)
+        try {
+            return resolvedServiceInfoFuture[
+                SERVICE_DISCOVERY_TIMEOUT.toMillis(),
+                TimeUnit.MILLISECONDS]
+        } finally {
+            nsdManager.unregisterServiceInfoCallback(callback)
+        }
+    }
+
+    /**
+     * Returns the part of `netData` from "Prefixes:" (inclusive) up to "Routes:" (exclusive).
+     * Throws [IllegalArgumentException] if either marker is missing or they are out of order.
+     */
+    @JvmStatic
+    fun getPrefixesFromNetData(netData: String): String {
+        val startIdx = netData.indexOf("Prefixes:")
+        val endIdx = netData.indexOf("Routes:")
+        require(startIdx >= 0 && endIdx >= startIdx) { "Unexpected network data: $netData" }
+        return netData.substring(startIdx, endIdx)
+    }
+
+    /**
+     * Returns the first available Thread [Network], waiting at most `timeout` for one to appear.
+     * Throws [TimeoutException] if none becomes available in time.
+     */
+    @JvmStatic
+    @Throws(Exception::class)
+    fun getThreadNetwork(timeout: Duration): Network {
+        val networkFuture = CompletableFuture<Network>()
+        val cm =
+            ApplicationProvider.getApplicationContext<Context>()
+                .getSystemService(ConnectivityManager::class.java)
+        val networkRequestBuilder =
+            NetworkRequest.Builder().addTransportType(NetworkCapabilities.TRANSPORT_THREAD)
+        // Before V, we need to explicitly set `NET_CAPABILITY_LOCAL_NETWORK` capability to request
+        // a Thread network.
+        if (Build.VERSION.SDK_INT <= Build.VERSION_CODES.UPSIDE_DOWN_CAKE) {
+            networkRequestBuilder.addCapability(NetworkCapabilities.NET_CAPABILITY_LOCAL_NETWORK)
+        }
+        val networkRequest = networkRequestBuilder.build()
+        val networkCallback: ConnectivityManager.NetworkCallback =
+            object : ConnectivityManager.NetworkCallback() {
+                override fun onAvailable(network: Network) {
+                    networkFuture.complete(network)
+                }
+            }
+        cm.registerNetworkCallback(networkRequest, networkCallback)
+        try {
+            // Wait in milliseconds so that sub-second timeouts are not truncated to zero.
+            return networkFuture[timeout.toMillis(), TimeUnit.MILLISECONDS]
+        } finally {
+            // Unregister on every path so the callback doesn't outlive this call.
+            cm.unregisterNetworkCallback(networkCallback)
+        }
+    }
+
+    /**
+     * Lets the FTD join the specified Thread network and waits for border routing to be available.
+     *
+     * @return the OMR address
+     */
+    @JvmStatic
+    @Throws(Exception::class)
+    fun joinNetworkAndWaitForOmr(
+        ftd: FullThreadDevice, dataset: ActiveOperationalDataset
+    ): Inet6Address {
+        ftd.factoryReset()
+        ftd.joinNetwork(dataset)
+        ftd.waitForStateAnyOf(listOf("router", "child"), JOIN_TIMEOUT)
+        waitFor({ ftd.omrAddress != null }, Duration.ofSeconds(60))
+        Assert.assertNotNull(ftd.omrAddress)
+        return ftd.omrAddress
+    }
+
+    /** An [NsdManager.DiscoveryListener] whose callbacks all do nothing; override as needed. */
+    private open class DefaultDiscoveryListener : NsdManager.DiscoveryListener {
+        override fun onStartDiscoveryFailed(serviceType: String, errorCode: Int) {}
+        override fun onStopDiscoveryFailed(serviceType: String, errorCode: Int) {}
+        override fun onDiscoveryStarted(serviceType: String) {}
+        override fun onDiscoveryStopped(serviceType: String) {}
+        override fun onServiceFound(serviceInfo: NsdServiceInfo) {}
+        override fun onServiceLost(serviceInfo: NsdServiceInfo) {}
+    }
+
+    /** An [NsdManager.ServiceInfoCallback] whose callbacks all do nothing; override as needed. */
+    private open class DefaultServiceInfoCallback : NsdManager.ServiceInfoCallback {
+        override fun onServiceInfoCallbackRegistrationFailed(errorCode: Int) {}
+        override fun onServiceUpdated(serviceInfo: NsdServiceInfo) {}
+        override fun onServiceLost() {}
+        override fun onServiceInfoCallbackUnregistered() {}
+    }
+
+    /**
+     * Parses a line of output from "ip -6 addr show" into a [LinkAddress].
+     *
+     * Example line: "inet6 2001:db8:1:1::1/64 scope global deprecated"
+     */
+    private fun parseAddressLine(line: String): LinkAddress {
+        val parts = line.split(WHITESPACE_REGEX).filter { it.isNotEmpty() }
+        val addressString = parts[1]
+        val pieces = addressString.split("/", limit = 2)
+        val prefixLength = pieces[1].toInt()
+        val address = parseNumericAddress(pieces[0])
+        val deprecationTimeMillis =
+            if (line.contains("deprecated")) SystemClock.elapsedRealtime()
+            else LinkAddress.LIFETIME_PERMANENT
+
+        return LinkAddress(
+            address, prefixLength,
+            0 /* flags */, 0 /* scope */,
+            deprecationTimeMillis, LinkAddress.LIFETIME_PERMANENT /* expirationTime */
+        )
+    }
+
+    @JvmStatic
+    @JvmOverloads
+    fun startInfraDeviceAndWaitForOnLinkAddr(
+        tapPacketReader: TapPacketReader,
+        macAddress: MacAddress = MacAddress.fromString("1:2:3:4:5:6")
+    ): InfraNetworkDevice {
+        val infraDevice = InfraNetworkDevice(macAddress, tapPacketReader)
+        infraDevice.runSlaac(Duration.ofSeconds(60))
+        requireNotNull(infraDevice.ipv6Addr)
+        return infraDevice
+    }
+
+    @JvmStatic
+    @Throws(Exception::class)
+    fun setUpInfraNetwork(
+        context: Context, controller: ThreadNetworkControllerWrapper
+    ): TestNetworkTracker {
+        val lp = LinkProperties()
+
+        // TODO: use a fake DNS server
+        lp.setDnsServers(listOf(parseNumericAddress("8.8.8.8")))
+        // NAT64 feature requires the infra network to have an IPv4 default route.
+        lp.addRoute(
+            RouteInfo(
+                IpPrefix("0.0.0.0/0") /* destination */,
+                null /* gateway */,
+                null /* iface */,
+                RouteInfo.RTN_UNICAST, 1500 /* mtu */
+            )
+        )
+        val infraNetworkTracker: TestNetworkTracker =
+            runAsShell(
+                MANAGE_TEST_NETWORKS,
+                supplier = { initTestNetwork(context, lp, setupTimeoutMs = 5000) })
+        val infraNetworkName: String = infraNetworkTracker.testIface.interfaceName
+        controller.setTestNetworkAsUpstreamAndWait(infraNetworkName)
+
+        return infraNetworkTracker
+    }
+
+    @JvmStatic
+    fun tearDownInfraNetwork(testNetworkTracker: TestNetworkTracker) {
+        runAsShell(MANAGE_TEST_NETWORKS) { testNetworkTracker.teardown() }
+    }
+}