Snap for 12327203 from a6648cbbf2a3d9aa94c9fd2bddec33e1e0dd774e to 24Q4-release
Change-Id: I55bdf605ac3c0e02f0c9f2b16003dd9a9aaa4533
diff --git a/Tethering/tests/integration/base/android/net/EthernetTetheringTestBase.java b/Tethering/tests/integration/base/android/net/EthernetTetheringTestBase.java
index 1eb6255..d5c6a8e 100644
--- a/Tethering/tests/integration/base/android/net/EthernetTetheringTestBase.java
+++ b/Tethering/tests/integration/base/android/net/EthernetTetheringTestBase.java
@@ -568,6 +568,12 @@
return nif.getMTU();
}
+ protected int getIndexByName(String ifaceName) throws SocketException {
+ NetworkInterface nif = NetworkInterface.getByName(ifaceName);
+ assertNotNull("Can't get NetworkInterface object for " + ifaceName, nif);
+ return nif.getIndex();
+ }
+
protected TapPacketReader makePacketReader(final TestNetworkInterface iface) throws Exception {
FileDescriptor fd = iface.getFileDescriptor().getFileDescriptor();
return makePacketReader(fd, getMTU(iface));
@@ -968,6 +974,14 @@
return Struct.parse(Ipv6Header.class, ByteBuffer.wrap(expectedPacket)).srcIp;
}
+ /**
+ * Returns the interface name of the upstream test network, or null if there is no upstream tracker.
+ */
+ protected String getUpstreamInterfaceName() {
+ if (mUpstreamTracker == null) return null;
+ return mUpstreamTracker.getTestIface().getInterfaceName();
+ }
+
protected <T> List<T> toList(T... array) {
return Arrays.asList(array);
}
diff --git a/Tethering/tests/integration/src/android/net/EthernetTetheringTest.java b/Tethering/tests/integration/src/android/net/EthernetTetheringTest.java
index 049f5f0..32b2f3e 100644
--- a/Tethering/tests/integration/src/android/net/EthernetTetheringTest.java
+++ b/Tethering/tests/integration/src/android/net/EthernetTetheringTest.java
@@ -1066,24 +1066,34 @@
runUdp4Test();
}
- private ClatEgress4Value getClatEgress4Value() throws Exception {
+ private ClatEgress4Value getClatEgress4Value(int clatIfaceIndex) throws Exception {
// Command: dumpsys connectivity clatEgress4RawBpfMap
final String[] args = new String[] {DUMPSYS_CLAT_RAWMAP_EGRESS4_ARG};
final HashMap<ClatEgress4Key, ClatEgress4Value> egress4Map = pollRawMapFromDump(
ClatEgress4Key.class, ClatEgress4Value.class, Context.CONNECTIVITY_SERVICE, args);
assertNotNull(egress4Map);
- assertEquals(1, egress4Map.size());
- return egress4Map.entrySet().iterator().next().getValue();
+ for (Map.Entry<ClatEgress4Key, ClatEgress4Value> entry : egress4Map.entrySet()) {
+ ClatEgress4Key key = entry.getKey();
+ if (key.iif == clatIfaceIndex) {
+ return entry.getValue();
+ }
+ }
+ return null;
}
- private ClatIngress6Value getClatIngress6Value() throws Exception {
+ private ClatIngress6Value getClatIngress6Value(int ifaceIndex) throws Exception {
// Command: dumpsys connectivity clatIngress6RawBpfMap
final String[] args = new String[] {DUMPSYS_CLAT_RAWMAP_INGRESS6_ARG};
final HashMap<ClatIngress6Key, ClatIngress6Value> ingress6Map = pollRawMapFromDump(
ClatIngress6Key.class, ClatIngress6Value.class, Context.CONNECTIVITY_SERVICE, args);
assertNotNull(ingress6Map);
- assertEquals(1, ingress6Map.size());
- return ingress6Map.entrySet().iterator().next().getValue();
+ for (Map.Entry<ClatIngress6Key, ClatIngress6Value> entry : ingress6Map.entrySet()) {
+ ClatIngress6Key key = entry.getKey();
+ if (key.iif == ifaceIndex) {
+ return entry.getValue();
+ }
+ }
+ return null;
}
/**
@@ -1115,8 +1125,13 @@
final Inet6Address clatIp6 = getClatIpv6Address(tester, tethered);
// Get current values before sending packets.
- final ClatEgress4Value oldEgress4 = getClatEgress4Value();
- final ClatIngress6Value oldIngress6 = getClatIngress6Value();
+ final String ifaceName = getUpstreamInterfaceName();
+ final int ifaceIndex = getIndexByName(ifaceName);
+ final int clatIfaceIndex = getIndexByName("v4-" + ifaceName);
+ final ClatEgress4Value oldEgress4 = getClatEgress4Value(clatIfaceIndex);
+ final ClatIngress6Value oldIngress6 = getClatIngress6Value(ifaceIndex);
+ assertNotNull(oldEgress4);
+ assertNotNull(oldIngress6);
// Send an IPv4 UDP packet in original direction.
// IPv4 packet -- CLAT translation --> IPv6 packet
@@ -1145,8 +1160,10 @@
ByteBuffer.wrap(payload), l2mtu);
// After sending test packets, get stats again to verify their differences.
- final ClatEgress4Value newEgress4 = getClatEgress4Value();
- final ClatIngress6Value newIngress6 = getClatIngress6Value();
+ final ClatEgress4Value newEgress4 = getClatEgress4Value(clatIfaceIndex);
+ final ClatIngress6Value newIngress6 = getClatIngress6Value(ifaceIndex);
+ assertNotNull(newEgress4);
+ assertNotNull(newIngress6);
assertEquals(RX_UDP_PACKET_COUNT + fragPktCnt, newIngress6.packets - oldIngress6.packets);
assertEquals(RX_UDP_PACKET_COUNT * RX_UDP_PACKET_SIZE + fragRxBytes,
diff --git a/bpf/headers/include/bpf_helpers.h b/bpf/headers/include/bpf_helpers.h
index 1a9fd31..ac5ffda 100644
--- a/bpf/headers/include/bpf_helpers.h
+++ b/bpf/headers/include/bpf_helpers.h
@@ -291,6 +291,12 @@
bpf_ringbuf_submit_unsafe(v, 0); \
}
+#define DEFINE_BPF_RINGBUF(the_map, ValueType, size_bytes, usr, grp, md) \
+ DEFINE_BPF_RINGBUF_EXT(the_map, ValueType, size_bytes, usr, grp, md, \
+ DEFAULT_BPF_MAP_SELINUX_CONTEXT, DEFAULT_BPF_MAP_PIN_SUBDIR, \
+ PRIVATE, BPFLOADER_MIN_VER, BPFLOADER_MAX_VER, \
+ LOAD_ON_ENG, LOAD_ON_USER, LOAD_ON_USERDEBUG)
+
/* There exist buggy kernels with pre-T OS, that due to
* kernel patch "[ALPS05162612] bpf: fix ubsan error"
* do not support userspace writes into non-zero index of bpf map arrays.
@@ -349,11 +355,17 @@
#error "Bpf Map UID must be left at default of AID_ROOT for BpfLoader prior to v0.28"
#endif
-#define DEFINE_BPF_MAP_UGM(the_map, TYPE, KeyType, ValueType, num_entries, usr, grp, md) \
- DEFINE_BPF_MAP_EXT(the_map, TYPE, KeyType, ValueType, num_entries, usr, grp, md, \
- DEFAULT_BPF_MAP_SELINUX_CONTEXT, DEFAULT_BPF_MAP_PIN_SUBDIR, PRIVATE, \
- BPFLOADER_MIN_VER, BPFLOADER_MAX_VER, LOAD_ON_ENG, \
- LOAD_ON_USER, LOAD_ON_USERDEBUG)
+// for maps not meant to be accessed from userspace
+#define DEFINE_BPF_MAP_KERNEL_INTERNAL(the_map, TYPE, KeyType, ValueType, num_entries) \
+ DEFINE_BPF_MAP_EXT(the_map, TYPE, KeyType, ValueType, num_entries, AID_ROOT, AID_ROOT, \
+ 0000, "fs_bpf_loader", "", PRIVATE, BPFLOADER_MIN_VER, BPFLOADER_MAX_VER, \
+ LOAD_ON_ENG, LOAD_ON_USER, LOAD_ON_USERDEBUG)
+
+#define DEFINE_BPF_MAP_UGM(the_map, TYPE, KeyType, ValueType, num_entries, usr, grp, md) \
+ DEFINE_BPF_MAP_EXT(the_map, TYPE, KeyType, ValueType, num_entries, usr, grp, md, \
+ DEFAULT_BPF_MAP_SELINUX_CONTEXT, DEFAULT_BPF_MAP_PIN_SUBDIR, \
+ PRIVATE, BPFLOADER_MIN_VER, BPFLOADER_MAX_VER, \
+ LOAD_ON_ENG, LOAD_ON_USER, LOAD_ON_USERDEBUG)
#define DEFINE_BPF_MAP(the_map, TYPE, KeyType, ValueType, num_entries) \
DEFINE_BPF_MAP_UGM(the_map, TYPE, KeyType, ValueType, num_entries, \
diff --git a/bpf/loader/NetBpfLoad.cpp b/bpf/loader/NetBpfLoad.cpp
index 04393a0..4767dfa 100644
--- a/bpf/loader/NetBpfLoad.cpp
+++ b/bpf/loader/NetBpfLoad.cpp
@@ -60,7 +60,14 @@
#include "bpf_map_def.h"
using android::base::EndsWith;
+using android::base::GetIntProperty;
+using android::base::GetProperty;
+using android::base::InitLogging;
+using android::base::KernelLogger;
+using android::base::SetProperty;
+using android::base::Split;
using android::base::StartsWith;
+using android::base::Tokenize;
using android::base::unique_fd;
using std::ifstream;
using std::ios;
@@ -90,6 +97,8 @@
net_shared, // (T+) fs_bpf_net_shared /sys/fs/bpf/net_shared
netd_readonly, // (T+) fs_bpf_netd_readonly /sys/fs/bpf/netd_readonly
netd_shared, // (T+) fs_bpf_netd_shared /sys/fs/bpf/netd_shared
+ loader, // (U+) fs_bpf_loader /sys/fs/bpf/loader
+ // on T due to lack of sepolicy/genfscon rules it behaves simply as 'fs_bpf'
};
static constexpr domain AllDomains[] = {
@@ -99,6 +108,7 @@
domain::net_shared,
domain::netd_readonly,
domain::netd_shared,
+ domain::loader,
};
static constexpr bool specified(domain d) {
@@ -112,7 +122,7 @@
// Returns the build type string (from ro.build.type).
const std::string& getBuildType() {
- static std::string t = android::base::GetProperty("ro.build.type", "unknown");
+ static std::string t = GetProperty("ro.build.type", "unknown");
return t;
}
@@ -144,6 +154,7 @@
case domain::net_shared: return "fs_bpf_net_shared";
case domain::netd_readonly: return "fs_bpf_netd_readonly";
case domain::netd_shared: return "fs_bpf_netd_shared";
+ case domain::loader: return "fs_bpf_loader";
}
}
@@ -167,6 +178,7 @@
case domain::net_shared: return "net_shared/";
case domain::netd_readonly: return "netd_readonly/";
case domain::netd_shared: return "netd_shared/";
+ case domain::loader: return "loader/";
}
};
@@ -184,7 +196,7 @@
static string pathToObjName(const string& path) {
// extract everything after the final slash, ie. this is the filename 'foo@1.o' or 'bar.o'
- string filename = android::base::Split(path, "/").back();
+ string filename = Split(path, "/").back();
// strip off everything from the final period onwards (strip '.o' suffix), ie. 'foo@1' or 'bar'
string name = filename.substr(0, filename.find_last_of('.'));
// strip any potential @1 suffix, this will leave us with just 'foo' or 'bar'
@@ -1016,7 +1028,7 @@
if (!fd.ok()) {
if (log_buf.size()) {
- vector<string> lines = android::base::Split(log_buf.data(), "\n");
+ vector<string> lines = Split(log_buf.data(), "\n");
ALOGW("BPF_PROG_LOAD - BEGIN log_buf contents:");
for (const auto& line : lines) ALOGW("%s", line.c_str());
@@ -1247,7 +1259,7 @@
// to include a newline to match 'echo "value" > /proc/sys/...foo' behaviour,
// which is usually how kernel devs test the actual sysctl interfaces.
static int writeProcSysFile(const char *filename, const char *value) {
- base::unique_fd fd(open(filename, O_WRONLY | O_CLOEXEC));
+ unique_fd fd(open(filename, O_WRONLY | O_CLOEXEC));
if (fd < 0) {
const int err = errno;
ALOGE("open('%s', O_WRONLY | O_CLOEXEC) -> %s", filename, strerror(err));
@@ -1324,7 +1336,7 @@
}
static bool hasGSM() {
- static string ph = base::GetProperty("gsm.current.phone-type", "");
+ static string ph = GetProperty("gsm.current.phone-type", "");
static bool gsm = (ph != "");
static bool logged = false;
if (!logged) {
@@ -1337,7 +1349,7 @@
static bool isTV() {
if (hasGSM()) return false; // TVs don't do GSM
- static string key = base::GetProperty("ro.oem.key1", "");
+ static string key = GetProperty("ro.oem.key1", "");
static bool tv = StartsWith(key, "ATV00");
static bool logged = false;
if (!logged) {
@@ -1348,10 +1360,10 @@
}
static bool isWear() {
- static string wearSdkStr = base::GetProperty("ro.cw_build.wear_sdk.version", "");
- static int wearSdkInt = base::GetIntProperty("ro.cw_build.wear_sdk.version", 0);
- static string buildChars = base::GetProperty("ro.build.characteristics", "");
- static vector<string> v = base::Tokenize(buildChars, ",");
+ static string wearSdkStr = GetProperty("ro.cw_build.wear_sdk.version", "");
+ static int wearSdkInt = GetIntProperty("ro.cw_build.wear_sdk.version", 0);
+ static string buildChars = GetProperty("ro.build.characteristics", "");
+ static vector<string> v = Tokenize(buildChars, ",");
static bool watch = (std::find(v.begin(), v.end(), "watch") != v.end());
static bool wear = (wearSdkInt > 0) || watch;
static bool logged = false;
@@ -1368,7 +1380,7 @@
// Any released device will have codename REL instead of a 'real' codename.
// For safety: default to 'REL' so we default to unreleased=false on failure.
- const bool unreleased = (base::GetProperty("ro.build.version.codename", "REL") != "REL");
+ const bool unreleased = (GetProperty("ro.build.version.codename", "REL") != "REL");
// goog/main device_api_level is bumped *way* before aosp/main api level
// (the latter only gets bumped during the push of goog/main to aosp/main)
@@ -1399,7 +1411,7 @@
const bool isAtLeastV = (effective_api_level >= __ANDROID_API_V__);
const bool isAtLeastW = (effective_api_level > __ANDROID_API_V__); // TODO: switch to W
- const int first_api_level = base::GetIntProperty("ro.board.first_api_level", effective_api_level);
+ const int first_api_level = GetIntProperty("ro.board.first_api_level", effective_api_level);
// last in U QPR2 beta1
const bool has_platform_bpfloader_rc = exists("/system/etc/init/bpfloader.rc");
@@ -1620,7 +1632,7 @@
int key = 1;
int value = 123;
- base::unique_fd map(
+ unique_fd map(
createMap(BPF_MAP_TYPE_ARRAY, sizeof(key), sizeof(value), 2, 0));
if (writeToMapEntry(map, &key, &value, BPF_ANY)) {
ALOGE("Critical kernel bug - failure to write into index 1 of 2 element bpf map array.");
@@ -1652,11 +1664,11 @@
} // namespace android
int main(int argc, char** argv, char * const envp[]) {
- android::base::InitLogging(argv, &android::base::KernelLogger);
+ InitLogging(argv, &KernelLogger);
if (argc == 2 && !strcmp(argv[1], "done")) {
// we're being re-exec'ed from platform bpfloader to 'finalize' things
- if (!android::base::SetProperty("bpf.progs_loaded", "1")) {
+ if (!SetProperty("bpf.progs_loaded", "1")) {
ALOGE("Failed to set bpf.progs_loaded property to 1.");
return 125;
}
diff --git a/bpf/progs/bpf_net_helpers.h b/bpf/progs/bpf_net_helpers.h
index a86c3e6..a5664ba 100644
--- a/bpf/progs/bpf_net_helpers.h
+++ b/bpf/progs/bpf_net_helpers.h
@@ -139,6 +139,24 @@
if (skb->data_end - skb->data < len) bpf_skb_pull_data(skb, len);
}
+// anti-compiler-optimizer no-op: explicitly force full calculation of 'v'
+//
+// The use for this is to force full calculation of a complex arithmetic (likely binary
+// bitops) value, and then check the result only once (thus likely reducing the number
+// of required conditional jump instructions that badly affect bpf verifier runtime)
+//
+// The compiler cannot look into the assembly statement, so it doesn't know it does nothing.
+// Since the statement takes 'v' as both input and output in a register (+r),
+// the compiler must fully calculate the precise value of 'v' before this,
+// and must use the (possibly modified) value of 'v' afterwards (thus cannot
+// do funky optimizations to use partial results from before the asm).
+//
+// As this is not flagged 'volatile' this may still be moved out of a loop,
+// or even entirely optimized out if 'v' is never used afterwards.
+//
+// See: https://gcc.gnu.org/onlinedocs/gcc/Extended-Asm.html
+#define COMPILER_FORCE_CALCULATION(v) asm ("" : "+r" (v))
+
struct egress_bool { bool egress; };
#define INGRESS ((struct egress_bool){ .egress = false })
#define EGRESS ((struct egress_bool){ .egress = true })
diff --git a/bpf/progs/dscpPolicy.c b/bpf/progs/dscpPolicy.c
index 39f2961..de9723d 100644
--- a/bpf/progs/dscpPolicy.c
+++ b/bpf/progs/dscpPolicy.c
@@ -25,8 +25,8 @@
// The cache is never read nor written by userspace and is indexed by socket cookie % CACHE_MAP_SIZE
#define CACHE_MAP_SIZE 32 // should be a power of two so we can % cheaply
-DEFINE_BPF_MAP_GRO(socket_policy_cache_map, PERCPU_ARRAY, uint32_t, RuleEntry, CACHE_MAP_SIZE,
- AID_SYSTEM)
+DEFINE_BPF_MAP_KERNEL_INTERNAL(socket_policy_cache_map, PERCPU_ARRAY, uint32_t, RuleEntry,
+ CACHE_MAP_SIZE)
DEFINE_BPF_MAP_GRW(ipv4_dscp_policies_map, ARRAY, uint32_t, DscpPolicy, MAX_POLICIES, AID_SYSTEM)
DEFINE_BPF_MAP_GRW(ipv6_dscp_policies_map, ARRAY, uint32_t, DscpPolicy, MAX_POLICIES, AID_SYSTEM)
@@ -113,14 +113,30 @@
// this array lookup cannot actually fail
RuleEntry* existing_rule = bpf_socket_policy_cache_map_lookup_elem(&cacheid);
- if (existing_rule &&
- v6_equal(src_ip, existing_rule->src_ip) &&
- v6_equal(dst_ip, existing_rule->dst_ip) &&
- skb->ifindex == existing_rule->ifindex &&
- sport == existing_rule->src_port &&
- dport == existing_rule->dst_port &&
- protocol == existing_rule->proto) {
- if (existing_rule->dscp_val < 0) return;
+ if (!existing_rule) return; // impossible
+
+ uint64_t nomatch = 0;
+ nomatch |= v6_not_equal(src_ip, existing_rule->src_ip);
+ nomatch |= v6_not_equal(dst_ip, existing_rule->dst_ip);
+ nomatch |= (skb->ifindex ^ existing_rule->ifindex);
+ nomatch |= (sport ^ existing_rule->src_port);
+ nomatch |= (dport ^ existing_rule->dst_port);
+ nomatch |= (protocol ^ existing_rule->proto);
+ COMPILER_FORCE_CALCULATION(nomatch);
+
+ /*
+ * After the above funky bitwise arithmetic we have 'nomatch == 0' iff
+ * src_ip == existing_rule->src_ip &&
+ * dst_ip == existing_rule->dst_ip &&
+ * skb->ifindex == existing_rule->ifindex &&
+ * sport == existing_rule->src_port &&
+ * dport == existing_rule->dst_port &&
+ * protocol == existing_rule->proto
+ */
+
+ if (!nomatch) {
+ if (existing_rule->dscp_val < 0) return; // cached no-op
+
if (ipv4) {
uint8_t newTos = UPDATE_TOS(existing_rule->dscp_val, tos);
bpf_l3_csum_replace(skb, l2_header_size + IP4_OFFSET(check), htons(tos), htons(newTos),
@@ -132,7 +148,7 @@
bpf_skb_store_bytes(skb, l2_header_size, &new_first_be32, sizeof(__be32),
BPF_F_RECOMPUTE_CSUM);
}
- return;
+ return; // cached DSCP mutation
}
// Linear scan ipv4_dscp_policies_map since no stored params match skb.
@@ -187,7 +203,8 @@
}
}
- RuleEntry value = {
+ // Update cache with found policy.
+ *existing_rule = (RuleEntry){
.src_ip = src_ip,
.dst_ip = dst_ip,
.ifindex = skb->ifindex,
@@ -197,9 +214,6 @@
.dscp_val = new_dscp,
};
- // Update cache with found policy.
- bpf_socket_policy_cache_map_update_elem(&cacheid, &value, BPF_ANY);
-
if (new_dscp < 0) return;
// Need to store bytes after updating map or program will not load.
diff --git a/bpf/progs/dscpPolicy.h b/bpf/progs/dscpPolicy.h
index 6a6b711..413fb0f 100644
--- a/bpf/progs/dscpPolicy.h
+++ b/bpf/progs/dscpPolicy.h
@@ -28,9 +28,6 @@
#define v6_not_equal(a, b) ((v6_hi_be64(a) ^ v6_hi_be64(b)) \
| (v6_lo_be64(a) ^ v6_lo_be64(b)))
-// Returns 'a == b' as boolean
-#define v6_equal(a, b) (!v6_not_equal((a), (b)))
-
typedef struct {
struct in6_addr src_ip;
struct in6_addr dst_ip;
diff --git a/service/src/com/android/server/connectivity/DscpPolicyValue.java b/service/src/com/android/server/connectivity/DscpPolicyValue.java
index 7162a4a..a9100ac 100644
--- a/service/src/com/android/server/connectivity/DscpPolicyValue.java
+++ b/service/src/com/android/server/connectivity/DscpPolicyValue.java
@@ -117,8 +117,8 @@
this.proto = proto != -1 ? proto : 0;
this.dscp = dscp;
- this.match_src_ip = (this.src46 != EMPTY_ADDRESS_FIELD);
- this.match_dst_ip = (this.dst46 != EMPTY_ADDRESS_FIELD);
+ this.match_src_ip = (src46 != null);
+ this.match_dst_ip = (dst46 != null);
this.match_src_port = (srcPort != -1);
this.match_proto = (proto != -1);
}
diff --git a/staticlibs/framework/com/android/net/module/util/LocationPermissionChecker.java b/staticlibs/framework/com/android/net/module/util/LocationPermissionChecker.java
index 28c33f3..e4d25cd 100644
--- a/staticlibs/framework/com/android/net/module/util/LocationPermissionChecker.java
+++ b/staticlibs/framework/com/android/net/module/util/LocationPermissionChecker.java
@@ -117,7 +117,11 @@
@VisibleForTesting(visibility = PRIVATE)
public @LocationPermissionCheckStatus int checkLocationPermissionInternal(
String pkgName, @Nullable String featureId, int uid, @Nullable String message) {
- checkPackage(uid, pkgName);
+ try {
+ checkPackage(uid, pkgName);
+ } catch (SecurityException e) {
+ return ERROR_LOCATION_PERMISSION_MISSING;
+ }
// Apps with NETWORK_SETTINGS, NETWORK_SETUP_WIZARD, NETWORK_STACK & MAINLINE_NETWORK_STACK
// are granted a bypass.
diff --git a/staticlibs/tests/unit/src/com/android/net/module/util/LocationPermissionCheckerTest.java b/staticlibs/tests/unit/src/com/android/net/module/util/LocationPermissionCheckerTest.java
index c8f8656..d773374 100644
--- a/staticlibs/tests/unit/src/com/android/net/module/util/LocationPermissionCheckerTest.java
+++ b/staticlibs/tests/unit/src/com/android/net/module/util/LocationPermissionCheckerTest.java
@@ -18,7 +18,6 @@
import static android.Manifest.permission.NETWORK_SETTINGS;
import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyInt;
import static org.mockito.ArgumentMatchers.anyString;
@@ -47,7 +46,6 @@
import com.android.testutils.DevSdkIgnoreRule;
-import org.junit.Assert;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
@@ -242,9 +240,9 @@
mWifiScanAllowApps = AppOpsManager.MODE_ALLOWED;
setupTestCase();
- assertThrows(SecurityException.class,
- () -> mChecker.checkLocationPermissionInternal(
- TEST_PKG_NAME, TEST_FEATURE_ID, mUid, null));
+ final int result = mChecker.checkLocationPermissionInternal(
+ TEST_PKG_NAME, TEST_FEATURE_ID, mUid, null);
+ assertEquals(LocationPermissionChecker.ERROR_LOCATION_PERMISSION_MISSING, result);
}
@Test
@@ -305,14 +303,4 @@
TEST_PKG_NAME, TEST_FEATURE_ID, mUid, null);
assertEquals(LocationPermissionChecker.SUCCEEDED, result);
}
-
-
- private static void assertThrows(Class<? extends Exception> exceptionClass, Runnable r) {
- try {
- r.run();
- Assert.fail("Expected " + exceptionClass + " to be thrown.");
- } catch (Exception exception) {
- assertTrue(exceptionClass.isInstance(exception));
- }
- }
}
diff --git a/tests/cts/net/src/android/net/cts/DscpPolicyTest.kt b/tests/cts/net/src/android/net/cts/DscpPolicyTest.kt
index f73134a..041e6cb 100644
--- a/tests/cts/net/src/android/net/cts/DscpPolicyTest.kt
+++ b/tests/cts/net/src/android/net/cts/DscpPolicyTest.kt
@@ -298,7 +298,8 @@
fun sendPacket(
agent: TestableNetworkAgent,
sendV6: Boolean,
- dstPort: Int = 0
+ dstPort: Int = 0,
+ times: Int = 1
) {
val testString = "test string"
val testPacket = ByteBuffer.wrap(testString.toByteArray(Charsets.UTF_8))
@@ -308,9 +309,11 @@
IPPROTO_UDP)
checkNotNull(agent.network).bindSocket(socket)
- val originalPacket = testPacket.readAsArray()
- Os.sendto(socket, originalPacket, 0 /* bytesOffset */, originalPacket.size, 0 /* flags */,
+ val origPacket = testPacket.readAsArray()
+ repeat(times) {
+ Os.sendto(socket, origPacket, 0 /* bytesOffset */, origPacket.size, 0 /* flags */,
if (sendV6) TEST_TARGET_IPV6_ADDR else TEST_TARGET_IPV4_ADDR, dstPort)
+ }
Os.close(socket)
}
@@ -400,10 +403,11 @@
agent: TestableNetworkAgent,
sendV6: Boolean = false,
dscpValue: Int = 0,
- dstPort: Int = 0
+ dstPort: Int = 0,
+ times: Int = 1
) {
- var packetFound = false
- sendPacket(agent, sendV6, dstPort)
+ var packetFound = 0
+ sendPacket(agent, sendV6, dstPort, times)
// TODO: grab source port from socket in sendPacket
Log.e(TAG, "find DSCP value:" + dscpValue)
@@ -424,10 +428,23 @@
if (parsePacketIp(buffer, sendV6) && parsePacketPort(buffer, 0, dstPort)) {
Log.e(TAG, "DSCP value found")
assertEquals(dscpValue, dscp)
- packetFound = true
+ packetFound++
}
}
- assertTrue(packetFound)
+ assertEquals(times, packetFound)
+ }
+
+ fun validatePackets(
+ agent: TestableNetworkAgent,
+ sendV6: Boolean = false,
+ dscpValue: Int = 0,
+ dstPort: Int = 0
+ ) {
+ // We send two packets from the same socket to verify
+ // socket caching works correctly.
+ validatePacket(agent, sendV6, dscpValue, dstPort, 2)
+ // Try one more time from a different socket.
+ validatePacket(agent, sendV6, dscpValue, dstPort, 1)
}
fun doRemovePolicyTest(
@@ -453,10 +470,7 @@
assertEquals(1, it.policyId)
assertEquals(DSCP_POLICY_STATUS_SUCCESS, it.status)
}
- validatePacket(agent, dscpValue = 1, dstPort = 4444)
- // Send a second packet to validate that the stored BPF policy
- // is correct for subsequent packets.
- validatePacket(agent, dscpValue = 1, dstPort = 4444)
+ validatePackets(agent, dscpValue = 1, dstPort = 4444)
agent.sendRemoveDscpPolicy(1)
agent.expectCallback<OnDscpPolicyStatusUpdated>().let {
@@ -475,7 +489,7 @@
assertEquals(DSCP_POLICY_STATUS_SUCCESS, it.status)
}
- validatePacket(agent, dscpValue = 4, dstPort = 5555)
+ validatePackets(agent, dscpValue = 4, dstPort = 5555)
agent.sendRemoveDscpPolicy(1)
agent.expectCallback<OnDscpPolicyStatusUpdated>().let {
@@ -494,10 +508,7 @@
assertEquals(1, it.policyId)
assertEquals(DSCP_POLICY_STATUS_SUCCESS, it.status)
}
- validatePacket(agent, true, dscpValue = 1, dstPort = 4444)
- // Send a second packet to validate that the stored BPF policy
- // is correct for subsequent packets.
- validatePacket(agent, true, dscpValue = 1, dstPort = 4444)
+ validatePackets(agent, true, dscpValue = 1, dstPort = 4444)
agent.sendRemoveDscpPolicy(1)
agent.expectCallback<OnDscpPolicyStatusUpdated>().let {
@@ -515,7 +526,7 @@
assertEquals(1, it.policyId)
assertEquals(DSCP_POLICY_STATUS_SUCCESS, it.status)
}
- validatePacket(agent, true, dscpValue = 4, dstPort = 5555)
+ validatePackets(agent, true, dscpValue = 4, dstPort = 5555)
agent.sendRemoveDscpPolicy(1)
agent.expectCallback<OnDscpPolicyStatusUpdated>().let {
@@ -533,7 +544,7 @@
agent.expectCallback<OnDscpPolicyStatusUpdated>().let {
assertEquals(1, it.policyId)
assertEquals(DSCP_POLICY_STATUS_SUCCESS, it.status)
- validatePacket(agent, dscpValue = 1, dstPort = 1111)
+ validatePackets(agent, dscpValue = 1, dstPort = 1111)
}
val policy2 = DscpPolicy.Builder(2, 1).setDestinationPortRange(Range(2222, 2222)).build()
@@ -541,7 +552,7 @@
agent.expectCallback<OnDscpPolicyStatusUpdated>().let {
assertEquals(2, it.policyId)
assertEquals(DSCP_POLICY_STATUS_SUCCESS, it.status)
- validatePacket(agent, dscpValue = 1, dstPort = 2222)
+ validatePackets(agent, dscpValue = 1, dstPort = 2222)
}
val policy3 = DscpPolicy.Builder(3, 1).setDestinationPortRange(Range(3333, 3333)).build()
@@ -549,16 +560,16 @@
agent.expectCallback<OnDscpPolicyStatusUpdated>().let {
assertEquals(3, it.policyId)
assertEquals(DSCP_POLICY_STATUS_SUCCESS, it.status)
- validatePacket(agent, dscpValue = 1, dstPort = 3333)
+ validatePackets(agent, dscpValue = 1, dstPort = 3333)
}
/* Remove Policies and check CE is no longer set */
doRemovePolicyTest(agent, callback, 1)
- validatePacket(agent, dscpValue = 0, dstPort = 1111)
+ validatePackets(agent, dscpValue = 0, dstPort = 1111)
doRemovePolicyTest(agent, callback, 2)
- validatePacket(agent, dscpValue = 0, dstPort = 2222)
+ validatePackets(agent, dscpValue = 0, dstPort = 2222)
doRemovePolicyTest(agent, callback, 3)
- validatePacket(agent, dscpValue = 0, dstPort = 3333)
+ validatePackets(agent, dscpValue = 0, dstPort = 3333)
}
@Test
@@ -569,7 +580,7 @@
agent.expectCallback<OnDscpPolicyStatusUpdated>().let {
assertEquals(1, it.policyId)
assertEquals(DSCP_POLICY_STATUS_SUCCESS, it.status)
- validatePacket(agent, dscpValue = 1, dstPort = 1111)
+ validatePackets(agent, dscpValue = 1, dstPort = 1111)
}
doRemovePolicyTest(agent, callback, 1)
@@ -578,7 +589,7 @@
agent.expectCallback<OnDscpPolicyStatusUpdated>().let {
assertEquals(2, it.policyId)
assertEquals(DSCP_POLICY_STATUS_SUCCESS, it.status)
- validatePacket(agent, dscpValue = 1, dstPort = 2222)
+ validatePackets(agent, dscpValue = 1, dstPort = 2222)
}
doRemovePolicyTest(agent, callback, 2)
@@ -587,7 +598,7 @@
agent.expectCallback<OnDscpPolicyStatusUpdated>().let {
assertEquals(3, it.policyId)
assertEquals(DSCP_POLICY_STATUS_SUCCESS, it.status)
- validatePacket(agent, dscpValue = 1, dstPort = 3333)
+ validatePackets(agent, dscpValue = 1, dstPort = 3333)
}
doRemovePolicyTest(agent, callback, 3)
}
@@ -601,7 +612,7 @@
agent.expectCallback<OnDscpPolicyStatusUpdated>().let {
assertEquals(1, it.policyId)
assertEquals(DSCP_POLICY_STATUS_SUCCESS, it.status)
- validatePacket(agent, dscpValue = 1, dstPort = 1111)
+ validatePackets(agent, dscpValue = 1, dstPort = 1111)
}
val policy2 = DscpPolicy.Builder(2, 1).setDestinationPortRange(Range(2222, 2222)).build()
@@ -609,7 +620,7 @@
agent.expectCallback<OnDscpPolicyStatusUpdated>().let {
assertEquals(2, it.policyId)
assertEquals(DSCP_POLICY_STATUS_SUCCESS, it.status)
- validatePacket(agent, dscpValue = 1, dstPort = 2222)
+ validatePackets(agent, dscpValue = 1, dstPort = 2222)
}
val policy3 = DscpPolicy.Builder(3, 1).setDestinationPortRange(Range(3333, 3333)).build()
@@ -617,7 +628,7 @@
agent.expectCallback<OnDscpPolicyStatusUpdated>().let {
assertEquals(3, it.policyId)
assertEquals(DSCP_POLICY_STATUS_SUCCESS, it.status)
- validatePacket(agent, dscpValue = 1, dstPort = 3333)
+ validatePackets(agent, dscpValue = 1, dstPort = 3333)
}
/* Remove Policies and check CE is no longer set */
@@ -643,7 +654,7 @@
agent.expectCallback<OnDscpPolicyStatusUpdated>().let {
assertEquals(1, it.policyId)
assertEquals(DSCP_POLICY_STATUS_SUCCESS, it.status)
- validatePacket(agent, dscpValue = 1, dstPort = 1111)
+ validatePackets(agent, dscpValue = 1, dstPort = 1111)
}
val policy2 = DscpPolicy.Builder(2, 1)
@@ -652,7 +663,7 @@
agent.expectCallback<OnDscpPolicyStatusUpdated>().let {
assertEquals(2, it.policyId)
assertEquals(DSCP_POLICY_STATUS_SUCCESS, it.status)
- validatePacket(agent, dscpValue = 1, dstPort = 2222)
+ validatePackets(agent, dscpValue = 1, dstPort = 2222)
}
val policy3 = DscpPolicy.Builder(3, 1)
@@ -661,24 +672,24 @@
agent.expectCallback<OnDscpPolicyStatusUpdated>().let {
assertEquals(3, it.policyId)
assertEquals(DSCP_POLICY_STATUS_SUCCESS, it.status)
- validatePacket(agent, dscpValue = 1, dstPort = 3333)
+ validatePackets(agent, dscpValue = 1, dstPort = 3333)
}
agent.sendRemoveAllDscpPolicies()
agent.expectCallback<OnDscpPolicyStatusUpdated>().let {
assertEquals(1, it.policyId)
assertEquals(DSCP_POLICY_STATUS_DELETED, it.status)
- validatePacket(agent, false, dstPort = 1111)
+ validatePackets(agent, false, dstPort = 1111)
}
agent.expectCallback<OnDscpPolicyStatusUpdated>().let {
assertEquals(2, it.policyId)
assertEquals(DSCP_POLICY_STATUS_DELETED, it.status)
- validatePacket(agent, false, dstPort = 2222)
+ validatePackets(agent, false, dstPort = 2222)
}
agent.expectCallback<OnDscpPolicyStatusUpdated>().let {
assertEquals(3, it.policyId)
assertEquals(DSCP_POLICY_STATUS_DELETED, it.status)
- validatePacket(agent, false, dstPort = 3333)
+ validatePackets(agent, false, dstPort = 3333)
}
}
@@ -690,7 +701,7 @@
agent.expectCallback<OnDscpPolicyStatusUpdated>().let {
assertEquals(1, it.policyId)
assertEquals(DSCP_POLICY_STATUS_SUCCESS, it.status)
- validatePacket(agent, dscpValue = 1, dstPort = 4444)
+ validatePackets(agent, dscpValue = 1, dstPort = 4444)
}
val policy2 = DscpPolicy.Builder(1, 1).setDestinationPortRange(Range(5555, 5555)).build()
@@ -700,8 +711,8 @@
assertEquals(DSCP_POLICY_STATUS_SUCCESS, it.status)
// Sending packet with old policy should fail
- validatePacket(agent, dscpValue = 0, dstPort = 4444)
- validatePacket(agent, dscpValue = 1, dstPort = 5555)
+ validatePackets(agent, dscpValue = 0, dstPort = 4444)
+ validatePackets(agent, dscpValue = 1, dstPort = 5555)
}
agent.sendRemoveDscpPolicy(1)
diff --git a/thread/tests/integration/Android.bp b/thread/tests/integration/Android.bp
index 71693af..59e8e19 100644
--- a/thread/tests/integration/Android.bp
+++ b/thread/tests/integration/Android.bp
@@ -58,6 +58,7 @@
],
srcs: [
"src/**/*.java",
+ "src/**/*.kt",
],
compile_multilib: "both",
}
diff --git a/thread/tests/integration/src/android/net/thread/BorderRoutingTest.java b/thread/tests/integration/src/android/net/thread/BorderRoutingTest.java
index 9e8dc3a..103282a 100644
--- a/thread/tests/integration/src/android/net/thread/BorderRoutingTest.java
+++ b/thread/tests/integration/src/android/net/thread/BorderRoutingTest.java
@@ -34,7 +34,6 @@
import static com.android.net.module.util.NetworkStackConstants.ICMPV6_ECHO_REPLY_TYPE;
import static com.android.net.module.util.NetworkStackConstants.ICMPV6_ECHO_REQUEST_TYPE;
-import static com.android.testutils.TestNetworkTrackerKt.initTestNetwork;
import static com.android.testutils.TestPermissionUtil.runAsShell;
import static com.google.common.truth.Truth.assertThat;
@@ -49,11 +48,9 @@
import android.content.Context;
import android.net.IpPrefix;
import android.net.LinkAddress;
-import android.net.LinkProperties;
-import android.net.MacAddress;
-import android.net.RouteInfo;
import android.net.thread.utils.FullThreadDevice;
import android.net.thread.utils.InfraNetworkDevice;
+import android.net.thread.utils.IntegrationTestUtils;
import android.net.thread.utils.OtDaemonController;
import android.net.thread.utils.ThreadFeatureCheckerRule;
import android.net.thread.utils.ThreadFeatureCheckerRule.RequiresIpv6MulticastRouting;
@@ -634,32 +631,16 @@
}
private void setUpInfraNetwork() throws Exception {
- LinkProperties lp = new LinkProperties();
- // NAT64 feature requires the infra network to have an IPv4 default route.
- lp.addRoute(
- new RouteInfo(
- new IpPrefix("0.0.0.0/0") /* destination */,
- null /* gateway */,
- null,
- RouteInfo.RTN_UNICAST,
- 1500 /* mtu */));
- mInfraNetworkTracker =
- runAsShell(
- MANAGE_TEST_NETWORKS,
- () -> initTestNetwork(mContext, lp, 5000 /* timeoutMs */));
- String infraNetworkName = mInfraNetworkTracker.getTestIface().getInterfaceName();
- mController.setTestNetworkAsUpstreamAndWait(infraNetworkName);
+ mInfraNetworkTracker = IntegrationTestUtils.setUpInfraNetwork(mContext, mController);
}
private void tearDownInfraNetwork() {
- runAsShell(MANAGE_TEST_NETWORKS, () -> mInfraNetworkTracker.teardown());
+ IntegrationTestUtils.tearDownInfraNetwork(mInfraNetworkTracker);
}
- private void startInfraDeviceAndWaitForOnLinkAddr() throws Exception {
+ private void startInfraDeviceAndWaitForOnLinkAddr() {
mInfraDevice =
- new InfraNetworkDevice(MacAddress.fromString("1:2:3:4:5:6"), mInfraNetworkReader);
- mInfraDevice.runSlaac(Duration.ofSeconds(60));
- assertNotNull(mInfraDevice.ipv6Addr);
+ IntegrationTestUtils.startInfraDeviceAndWaitForOnLinkAddr(mInfraNetworkReader);
}
private void assertInfraLinkMemberOfGroup(Inet6Address address) throws Exception {
diff --git a/thread/tests/integration/src/android/net/thread/utils/IntegrationTestUtils.java b/thread/tests/integration/src/android/net/thread/utils/IntegrationTestUtils.java
deleted file mode 100644
index 82e9332..0000000
--- a/thread/tests/integration/src/android/net/thread/utils/IntegrationTestUtils.java
+++ /dev/null
@@ -1,563 +0,0 @@
-/*
- * Copyright (C) 2023 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package android.net.thread.utils;
-
-import static android.net.NetworkCapabilities.NET_CAPABILITY_LOCAL_NETWORK;
-import static android.system.OsConstants.IPPROTO_ICMP;
-import static android.system.OsConstants.IPPROTO_ICMPV6;
-
-import static com.android.compatibility.common.util.SystemUtil.runShellCommandOrThrow;
-import static com.android.net.module.util.NetworkStackConstants.ICMPV6_ND_OPTION_PIO;
-import static com.android.net.module.util.NetworkStackConstants.ICMPV6_ROUTER_ADVERTISEMENT;
-
-import static com.google.common.io.BaseEncoding.base16;
-import static com.google.common.util.concurrent.MoreExecutors.directExecutor;
-
-import static org.junit.Assert.assertNotNull;
-
-import static java.util.concurrent.TimeUnit.MILLISECONDS;
-import static java.util.concurrent.TimeUnit.SECONDS;
-
-import android.net.ConnectivityManager;
-import android.net.InetAddresses;
-import android.net.LinkAddress;
-import android.net.Network;
-import android.net.NetworkCapabilities;
-import android.net.NetworkRequest;
-import android.net.TestNetworkInterface;
-import android.net.nsd.NsdManager;
-import android.net.nsd.NsdServiceInfo;
-import android.net.thread.ActiveOperationalDataset;
-import android.net.thread.ThreadNetworkController;
-import android.os.Build;
-import android.os.Handler;
-import android.os.SystemClock;
-
-import androidx.annotation.NonNull;
-import androidx.test.core.app.ApplicationProvider;
-
-import com.android.net.module.util.Struct;
-import com.android.net.module.util.structs.Icmpv4Header;
-import com.android.net.module.util.structs.Icmpv6Header;
-import com.android.net.module.util.structs.Ipv4Header;
-import com.android.net.module.util.structs.Ipv6Header;
-import com.android.net.module.util.structs.PrefixInformationOption;
-import com.android.net.module.util.structs.RaHeader;
-import com.android.testutils.HandlerUtils;
-import com.android.testutils.TapPacketReader;
-
-import com.google.common.util.concurrent.SettableFuture;
-
-import java.io.FileDescriptor;
-import java.io.IOException;
-import java.net.DatagramPacket;
-import java.net.DatagramSocket;
-import java.net.Inet4Address;
-import java.net.Inet6Address;
-import java.net.InetAddress;
-import java.net.InetSocketAddress;
-import java.net.SocketAddress;
-import java.nio.ByteBuffer;
-import java.time.Duration;
-import java.util.ArrayList;
-import java.util.List;
-import java.util.concurrent.CompletableFuture;
-import java.util.concurrent.ExecutionException;
-import java.util.concurrent.TimeUnit;
-import java.util.concurrent.TimeoutException;
-import java.util.function.Predicate;
-import java.util.function.Supplier;
-
-/** Static utility methods relating to Thread integration tests. */
-public final class IntegrationTestUtils {
- // The timeout of join() after restarting ot-daemon. The device needs to send 6 Link Request
- // every 5 seconds, followed by 4 Parent Request every second. So this value needs to be 40
- // seconds to be safe
- public static final Duration RESTART_JOIN_TIMEOUT = Duration.ofSeconds(40);
- public static final Duration JOIN_TIMEOUT = Duration.ofSeconds(30);
- public static final Duration LEAVE_TIMEOUT = Duration.ofSeconds(2);
- public static final Duration CALLBACK_TIMEOUT = Duration.ofSeconds(1);
- public static final Duration SERVICE_DISCOVERY_TIMEOUT = Duration.ofSeconds(20);
-
- // A valid Thread Active Operational Dataset generated from OpenThread CLI "dataset init new".
- private static final byte[] DEFAULT_DATASET_TLVS =
- base16().decode(
- "0E080000000000010000000300001335060004001FFFE002"
- + "08ACC214689BC40BDF0708FD64DB1225F47E0B0510F26B31"
- + "53760F519A63BAFDDFFC80D2AF030F4F70656E5468726561"
- + "642D643961300102D9A00410A245479C836D551B9CA557F7"
- + "B9D351B40C0402A0FFF8");
- public static final ActiveOperationalDataset DEFAULT_DATASET =
- ActiveOperationalDataset.fromThreadTlvs(DEFAULT_DATASET_TLVS);
-
- private IntegrationTestUtils() {}
-
- /**
- * Waits for the given {@link Supplier} to be true until given timeout.
- *
- * @param condition the condition to check
- * @param timeout the time to wait for the condition before throwing
- * @throws TimeoutException if the condition is still not met when the timeout expires
- */
- public static void waitFor(Supplier<Boolean> condition, Duration timeout)
- throws TimeoutException {
- final long intervalMills = 500;
- final long timeoutMills = timeout.toMillis();
-
- for (long i = 0; i < timeoutMills; i += intervalMills) {
- if (condition.get()) {
- return;
- }
- SystemClock.sleep(intervalMills);
- }
- if (condition.get()) {
- return;
- }
- throw new TimeoutException("The condition failed to become true in " + timeout);
- }
-
- /**
- * Creates a {@link TapPacketReader} given the {@link TestNetworkInterface} and {@link Handler}.
- *
- * @param testNetworkInterface the TUN interface of the test network
- * @param handler the handler to process the packets
- * @return the {@link TapPacketReader}
- */
- public static TapPacketReader newPacketReader(
- TestNetworkInterface testNetworkInterface, Handler handler) {
- FileDescriptor fd = testNetworkInterface.getFileDescriptor().getFileDescriptor();
- final TapPacketReader reader =
- new TapPacketReader(handler, fd, testNetworkInterface.getMtu());
- handler.post(() -> reader.start());
- HandlerUtils.waitForIdle(handler, 5000 /* timeout in milliseconds */);
- return reader;
- }
-
- /**
- * Waits for the Thread module to enter any state of the given {@code deviceRoles}.
- *
- * @param controller the {@link ThreadNetworkController}
- * @param deviceRoles the desired device roles. See also {@link
- * ThreadNetworkController.DeviceRole}
- * @param timeout the time to wait for the expected state before throwing
- * @return the {@link ThreadNetworkController.DeviceRole} after waiting
- * @throws TimeoutException if the device hasn't become any of expected roles until the timeout
- * expires
- */
- public static int waitForStateAnyOf(
- ThreadNetworkController controller, List<Integer> deviceRoles, Duration timeout)
- throws TimeoutException {
- SettableFuture<Integer> future = SettableFuture.create();
- ThreadNetworkController.StateCallback callback =
- newRole -> {
- if (deviceRoles.contains(newRole)) {
- future.set(newRole);
- }
- };
- controller.registerStateCallback(directExecutor(), callback);
- try {
- return future.get(timeout.toMillis(), TimeUnit.MILLISECONDS);
- } catch (InterruptedException | ExecutionException e) {
- throw new TimeoutException(
- String.format(
- "The device didn't become an expected role in %s: %s",
- timeout, e.getMessage()));
- } finally {
- controller.unregisterStateCallback(callback);
- }
- }
-
- /**
- * Polls for a packet from a given {@link TapPacketReader} that satisfies the {@code filter}.
- *
- * @param packetReader a TUN packet reader
- * @param filter the filter to be applied on the packet
- * @return the first IPv6 packet that satisfies the {@code filter}. If it has waited for more
- * than 3000ms to read the next packet, the method will return null
- */
- public static byte[] pollForPacket(TapPacketReader packetReader, Predicate<byte[]> filter) {
- byte[] packet;
- while ((packet = packetReader.poll(3000 /* timeoutMs */, filter)) != null) {
- return packet;
- }
- return null;
- }
-
- /** Returns {@code true} if {@code packet} is an ICMPv4 packet of given {@code type}. */
- public static boolean isExpectedIcmpv4Packet(byte[] packet, int type) {
- ByteBuffer buf = makeByteBuffer(packet);
- Ipv4Header header = extractIpv4Header(buf);
- if (header == null) {
- return false;
- }
- if (header.protocol != (byte) IPPROTO_ICMP) {
- return false;
- }
- try {
- return Struct.parse(Icmpv4Header.class, buf).type == (short) type;
- } catch (IllegalArgumentException ignored) {
- // It's fine that the passed in packet is malformed because it's could be sent
- // by anybody.
- }
- return false;
- }
-
- /** Returns {@code true} if {@code packet} is an ICMPv6 packet of given {@code type}. */
- public static boolean isExpectedIcmpv6Packet(byte[] packet, int type) {
- ByteBuffer buf = makeByteBuffer(packet);
- Ipv6Header header = extractIpv6Header(buf);
- if (header == null) {
- return false;
- }
- if (header.nextHeader != (byte) IPPROTO_ICMPV6) {
- return false;
- }
- try {
- return Struct.parse(Icmpv6Header.class, buf).type == (short) type;
- } catch (IllegalArgumentException ignored) {
- // It's fine that the passed in packet is malformed because it's could be sent
- // by anybody.
- }
- return false;
- }
-
- public static boolean isFrom(byte[] packet, InetAddress src) {
- if (src instanceof Inet4Address) {
- return isFromIpv4Source(packet, (Inet4Address) src);
- } else if (src instanceof Inet6Address) {
- return isFromIpv6Source(packet, (Inet6Address) src);
- }
- return false;
- }
-
- public static boolean isTo(byte[] packet, InetAddress dest) {
- if (dest instanceof Inet4Address) {
- return isToIpv4Destination(packet, (Inet4Address) dest);
- } else if (dest instanceof Inet6Address) {
- return isToIpv6Destination(packet, (Inet6Address) dest);
- }
- return false;
- }
-
- private static boolean isFromIpv4Source(byte[] packet, Inet4Address src) {
- Ipv4Header header = extractIpv4Header(makeByteBuffer(packet));
- return header != null && header.srcIp.equals(src);
- }
-
- private static boolean isFromIpv6Source(byte[] packet, Inet6Address src) {
- Ipv6Header header = extractIpv6Header(makeByteBuffer(packet));
- return header != null && header.srcIp.equals(src);
- }
-
- private static boolean isToIpv4Destination(byte[] packet, Inet4Address dest) {
- Ipv4Header header = extractIpv4Header(makeByteBuffer(packet));
- return header != null && header.dstIp.equals(dest);
- }
-
- private static boolean isToIpv6Destination(byte[] packet, Inet6Address dest) {
- Ipv6Header header = extractIpv6Header(makeByteBuffer(packet));
- return header != null && header.dstIp.equals(dest);
- }
-
- private static ByteBuffer makeByteBuffer(byte[] packet) {
- return packet == null ? null : ByteBuffer.wrap(packet);
- }
-
- private static Ipv4Header extractIpv4Header(ByteBuffer buf) {
- try {
- return Struct.parse(Ipv4Header.class, buf);
- } catch (IllegalArgumentException ignored) {
- // It's fine that the passed in packet is malformed because it's could be sent
- // by anybody.
- }
- return null;
- }
-
- private static Ipv6Header extractIpv6Header(ByteBuffer buf) {
- try {
- return Struct.parse(Ipv6Header.class, buf);
- } catch (IllegalArgumentException ignored) {
- // It's fine that the passed in packet is malformed because it's could be sent
- // by anybody.
- }
- return null;
- }
-
- /** Returns the Prefix Information Options (PIO) extracted from an ICMPv6 RA message. */
- public static List<PrefixInformationOption> getRaPios(byte[] raMsg) {
- final ArrayList<PrefixInformationOption> pioList = new ArrayList<>();
-
- if (raMsg == null) {
- return pioList;
- }
-
- final ByteBuffer buf = ByteBuffer.wrap(raMsg);
- final Ipv6Header ipv6Header = Struct.parse(Ipv6Header.class, buf);
- if (ipv6Header.nextHeader != (byte) IPPROTO_ICMPV6) {
- return pioList;
- }
-
- final Icmpv6Header icmpv6Header = Struct.parse(Icmpv6Header.class, buf);
- if (icmpv6Header.type != (short) ICMPV6_ROUTER_ADVERTISEMENT) {
- return pioList;
- }
-
- Struct.parse(RaHeader.class, buf);
- while (buf.position() < raMsg.length) {
- final int currentPos = buf.position();
- final int type = Byte.toUnsignedInt(buf.get());
- final int length = Byte.toUnsignedInt(buf.get());
- if (type == ICMPV6_ND_OPTION_PIO) {
- final ByteBuffer pioBuf =
- ByteBuffer.wrap(
- buf.array(),
- currentPos,
- Struct.getSize(PrefixInformationOption.class));
- final PrefixInformationOption pio =
- Struct.parse(PrefixInformationOption.class, pioBuf);
- pioList.add(pio);
-
- // Move ByteBuffer position to the next option.
- buf.position(currentPos + Struct.getSize(PrefixInformationOption.class));
- } else {
- // The length is in units of 8 octets.
- buf.position(currentPos + (length * 8));
- }
- }
- return pioList;
- }
-
- /**
- * Sends a UDP message to a destination.
- *
- * @param dstAddress the IP address of the destination
- * @param dstPort the port of the destination
- * @param message the message in UDP payload
- * @throws IOException if failed to send the message
- */
- public static void sendUdpMessage(InetAddress dstAddress, int dstPort, String message)
- throws IOException {
- SocketAddress dstSockAddr = new InetSocketAddress(dstAddress, dstPort);
-
- try (DatagramSocket socket = new DatagramSocket()) {
- socket.connect(dstSockAddr);
-
- byte[] msgBytes = message.getBytes();
- DatagramPacket packet = new DatagramPacket(msgBytes, msgBytes.length);
-
- socket.send(packet);
- }
- }
-
- public static boolean isInMulticastGroup(String interfaceName, Inet6Address address) {
- final String cmd = "ip -6 maddr show dev " + interfaceName;
- final String output = runShellCommandOrThrow(cmd);
- final String addressStr = address.getHostAddress();
- for (final String line : output.split("\\n")) {
- if (line.contains(addressStr)) {
- return true;
- }
- }
- return false;
- }
-
- public static List<LinkAddress> getIpv6LinkAddresses(String interfaceName) {
- List<LinkAddress> addresses = new ArrayList<>();
- final String cmd = " ip -6 addr show dev " + interfaceName;
- final String output = runShellCommandOrThrow(cmd);
-
- for (final String line : output.split("\\n")) {
- if (line.contains("inet6")) {
- addresses.add(parseAddressLine(line));
- }
- }
-
- return addresses;
- }
-
- /** Return the first discovered service of {@code serviceType}. */
- public static NsdServiceInfo discoverService(NsdManager nsdManager, String serviceType)
- throws Exception {
- CompletableFuture<NsdServiceInfo> serviceInfoFuture = new CompletableFuture<>();
- NsdManager.DiscoveryListener listener =
- new DefaultDiscoveryListener() {
- @Override
- public void onServiceFound(NsdServiceInfo serviceInfo) {
- serviceInfoFuture.complete(serviceInfo);
- }
- };
- nsdManager.discoverServices(serviceType, NsdManager.PROTOCOL_DNS_SD, listener);
- try {
- serviceInfoFuture.get(SERVICE_DISCOVERY_TIMEOUT.toMillis(), MILLISECONDS);
- } finally {
- nsdManager.stopServiceDiscovery(listener);
- }
-
- return serviceInfoFuture.get();
- }
-
- /**
- * Returns the {@link NsdServiceInfo} when a service instance of {@code serviceType} gets lost.
- */
- public static NsdManager.DiscoveryListener discoverForServiceLost(
- NsdManager nsdManager,
- String serviceType,
- CompletableFuture<NsdServiceInfo> serviceInfoFuture) {
- NsdManager.DiscoveryListener listener =
- new DefaultDiscoveryListener() {
- @Override
- public void onServiceLost(NsdServiceInfo serviceInfo) {
- serviceInfoFuture.complete(serviceInfo);
- }
- };
- nsdManager.discoverServices(serviceType, NsdManager.PROTOCOL_DNS_SD, listener);
- return listener;
- }
-
- /** Resolves the service. */
- public static NsdServiceInfo resolveService(NsdManager nsdManager, NsdServiceInfo serviceInfo)
- throws Exception {
- return resolveServiceUntil(nsdManager, serviceInfo, s -> true);
- }
-
- /** Returns the first resolved service that satisfies the {@code predicate}. */
- public static NsdServiceInfo resolveServiceUntil(
- NsdManager nsdManager, NsdServiceInfo serviceInfo, Predicate<NsdServiceInfo> predicate)
- throws Exception {
- CompletableFuture<NsdServiceInfo> resolvedServiceInfoFuture = new CompletableFuture<>();
- NsdManager.ServiceInfoCallback callback =
- new DefaultServiceInfoCallback() {
- @Override
- public void onServiceUpdated(@NonNull NsdServiceInfo serviceInfo) {
- if (predicate.test(serviceInfo)) {
- resolvedServiceInfoFuture.complete(serviceInfo);
- }
- }
- };
- nsdManager.registerServiceInfoCallback(serviceInfo, directExecutor(), callback);
- try {
- return resolvedServiceInfoFuture.get(
- SERVICE_DISCOVERY_TIMEOUT.toMillis(), MILLISECONDS);
- } finally {
- nsdManager.unregisterServiceInfoCallback(callback);
- }
- }
-
- public static String getPrefixesFromNetData(String netData) {
- int startIdx = netData.indexOf("Prefixes:");
- int endIdx = netData.indexOf("Routes:");
- return netData.substring(startIdx, endIdx);
- }
-
- public static Network getThreadNetwork(Duration timeout) throws Exception {
- CompletableFuture<Network> networkFuture = new CompletableFuture<>();
- ConnectivityManager cm =
- ApplicationProvider.getApplicationContext()
- .getSystemService(ConnectivityManager.class);
- NetworkRequest.Builder networkRequestBuilder =
- new NetworkRequest.Builder().addTransportType(NetworkCapabilities.TRANSPORT_THREAD);
- // Before V, we need to explicitly set `NET_CAPABILITY_LOCAL_NETWORK` capability to request
- // a Thread network.
- if (Build.VERSION.SDK_INT <= Build.VERSION_CODES.UPSIDE_DOWN_CAKE) {
- networkRequestBuilder.addCapability(NET_CAPABILITY_LOCAL_NETWORK);
- }
- NetworkRequest networkRequest = networkRequestBuilder.build();
- ConnectivityManager.NetworkCallback networkCallback =
- new ConnectivityManager.NetworkCallback() {
- @Override
- public void onAvailable(Network network) {
- networkFuture.complete(network);
- }
- };
- cm.registerNetworkCallback(networkRequest, networkCallback);
- return networkFuture.get(timeout.toSeconds(), SECONDS);
- }
-
- /**
- * Let the FTD join the specified Thread network and wait for border routing to be available.
- *
- * @return the OMR address
- */
- public static Inet6Address joinNetworkAndWaitForOmr(
- FullThreadDevice ftd, ActiveOperationalDataset dataset) throws Exception {
- ftd.factoryReset();
- ftd.joinNetwork(dataset);
- ftd.waitForStateAnyOf(List.of("router", "child"), JOIN_TIMEOUT);
- waitFor(() -> ftd.getOmrAddress() != null, Duration.ofSeconds(60));
- Inet6Address ftdOmr = ftd.getOmrAddress();
- assertNotNull(ftdOmr);
- return ftdOmr;
- }
-
- private static class DefaultDiscoveryListener implements NsdManager.DiscoveryListener {
- @Override
- public void onStartDiscoveryFailed(String serviceType, int errorCode) {}
-
- @Override
- public void onStopDiscoveryFailed(String serviceType, int errorCode) {}
-
- @Override
- public void onDiscoveryStarted(String serviceType) {}
-
- @Override
- public void onDiscoveryStopped(String serviceType) {}
-
- @Override
- public void onServiceFound(NsdServiceInfo serviceInfo) {}
-
- @Override
- public void onServiceLost(NsdServiceInfo serviceInfo) {}
- }
-
- private static class DefaultServiceInfoCallback implements NsdManager.ServiceInfoCallback {
- @Override
- public void onServiceInfoCallbackRegistrationFailed(int errorCode) {}
-
- @Override
- public void onServiceUpdated(@NonNull NsdServiceInfo serviceInfo) {}
-
- @Override
- public void onServiceLost() {}
-
- @Override
- public void onServiceInfoCallbackUnregistered() {}
- }
-
- /**
- * Parses a line of output from "ip -6 addr show" into a {@link LinkAddress}.
- *
- * <p>Example line: "inet6 2001:db8:1:1::1/64 scope global deprecated"
- */
- private static LinkAddress parseAddressLine(String line) {
- String[] parts = line.trim().split("\\s+");
- String addressString = parts[1];
- String[] pieces = addressString.split("/", 2);
- int prefixLength = Integer.parseInt(pieces[1]);
- final InetAddress address = InetAddresses.parseNumericAddress(pieces[0]);
- long deprecationTimeMillis =
- line.contains("deprecated")
- ? SystemClock.elapsedRealtime()
- : LinkAddress.LIFETIME_PERMANENT;
-
- return new LinkAddress(
- address,
- prefixLength,
- 0 /* flags */,
- 0 /* scope */,
- deprecationTimeMillis,
- LinkAddress.LIFETIME_PERMANENT /* expirationTime */);
- }
-}
diff --git a/thread/tests/integration/src/android/net/thread/utils/IntegrationTestUtils.kt b/thread/tests/integration/src/android/net/thread/utils/IntegrationTestUtils.kt
new file mode 100644
index 0000000..fa9855e
--- /dev/null
+++ b/thread/tests/integration/src/android/net/thread/utils/IntegrationTestUtils.kt
@@ -0,0 +1,598 @@
+/*
+ * Copyright (C) 2024 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package android.net.thread.utils
+
+import android.Manifest.permission.MANAGE_TEST_NETWORKS
+import android.content.Context
+import android.net.ConnectivityManager
+import android.net.InetAddresses.parseNumericAddress
+import android.net.IpPrefix
+import android.net.LinkAddress
+import android.net.LinkProperties
+import android.net.MacAddress
+import android.net.Network
+import android.net.NetworkCapabilities
+import android.net.NetworkRequest
+import android.net.RouteInfo
+import android.net.TestNetworkInterface
+import android.net.nsd.NsdManager
+import android.net.nsd.NsdServiceInfo
+import android.net.thread.ActiveOperationalDataset
+import android.net.thread.ThreadNetworkController
+import android.os.Build
+import android.os.Handler
+import android.os.SystemClock
+import android.system.OsConstants
+import androidx.test.core.app.ApplicationProvider
+import com.android.compatibility.common.util.SystemUtil.runShellCommandOrThrow
+import com.android.net.module.util.NetworkStackConstants
+import com.android.net.module.util.Struct
+import com.android.net.module.util.structs.Icmpv4Header
+import com.android.net.module.util.structs.Icmpv6Header
+import com.android.net.module.util.structs.Ipv4Header
+import com.android.net.module.util.structs.Ipv6Header
+import com.android.net.module.util.structs.PrefixInformationOption
+import com.android.net.module.util.structs.RaHeader
+import com.android.testutils.TapPacketReader
+import com.android.testutils.TestNetworkTracker
+import com.android.testutils.initTestNetwork
+import com.android.testutils.runAsShell
+import com.android.testutils.waitForIdle
+import com.google.common.io.BaseEncoding
+import com.google.common.util.concurrent.MoreExecutors
+import com.google.common.util.concurrent.MoreExecutors.directExecutor
+import com.google.common.util.concurrent.SettableFuture
+import java.io.IOException
+import java.lang.Byte.toUnsignedInt
+import java.net.DatagramPacket
+import java.net.DatagramSocket
+import java.net.Inet4Address
+import java.net.Inet6Address
+import java.net.InetAddress
+import java.net.InetSocketAddress
+import java.net.SocketAddress
+import java.nio.ByteBuffer
+import java.time.Duration
+import java.util.concurrent.CompletableFuture
+import java.util.concurrent.ExecutionException
+import java.util.concurrent.TimeUnit
+import java.util.concurrent.TimeoutException
+import java.util.function.Predicate
+import java.util.function.Supplier
+import org.junit.Assert
+
+/** Utilities for Thread integration tests. */
+object IntegrationTestUtils {
+    /**
+     * The timeout of join() after restarting ot-daemon. The device needs to send 6 Link Request
+     * every 5 seconds, followed by 4 Parent Request every second. So this value needs to be 40
+     * seconds to be safe.
+     */
+    @JvmField val RESTART_JOIN_TIMEOUT: Duration = Duration.ofSeconds(40)
+
+    // Timeouts shared by the Thread integration tests.
+    @JvmField val JOIN_TIMEOUT: Duration = Duration.ofSeconds(30)
+    @JvmField val LEAVE_TIMEOUT: Duration = Duration.ofSeconds(2)
+    @JvmField val CALLBACK_TIMEOUT: Duration = Duration.ofSeconds(1)
+    @JvmField val SERVICE_DISCOVERY_TIMEOUT: Duration = Duration.ofSeconds(20)
+
+    // A valid Thread Active Operational Dataset generated from OpenThread CLI "dataset init new",
+    // hex-encoded and split across lines for readability.
+    private val DEFAULT_DATASET_TLVS: ByteArray =
+        BaseEncoding.base16().decode(
+            "0E080000000000010000000300001335060004001FFFE002" +
+                "08ACC214689BC40BDF0708FD64DB1225F47E0B0510F26B31" +
+                "53760F519A63BAFDDFFC80D2AF030F4F70656E5468726561" +
+                "642D643961300102D9A00410A245479C836D551B9CA557F7" +
+                "B9D351B40C0402A0FFF8"
+        )
+
+    /** [ActiveOperationalDataset] parsed from [DEFAULT_DATASET_TLVS]. */
+    @JvmField val DEFAULT_DATASET: ActiveOperationalDataset =
+        ActiveOperationalDataset.fromThreadTlvs(DEFAULT_DATASET_TLVS)
+
+    /**
+     * Blocks until [condition] returns true, re-checking it every 500ms.
+     *
+     * @param condition the condition to check
+     * @param timeout the time to wait for the condition before throwing
+     * @throws TimeoutException if the condition is still false after the polling loop and one
+     * final check
+     */
+    @JvmStatic
+    @Throws(TimeoutException::class)
+    fun waitFor(condition: Supplier<Boolean>, timeout: Duration) {
+        val pollIntervalMs = 500L
+        for (elapsedMs in 0L until timeout.toMillis() step pollIntervalMs) {
+            if (condition.get()) return
+            SystemClock.sleep(pollIntervalMs)
+        }
+        // One last check so a condition that flipped during the final sleep still succeeds.
+        if (condition.get()) return
+        throw TimeoutException("The condition failed to become true in $timeout")
+    }
+
+    /**
+     * Builds a [TapPacketReader] over the TUN fd of [testNetworkInterface] and starts it on
+     * [handler], waiting up to 5 seconds for the handler to go idle before returning.
+     *
+     * @param testNetworkInterface the TUN interface of the test network
+     * @param handler the handler to process the packets
+     * @return the started [TapPacketReader]
+     */
+    @JvmStatic
+    fun newPacketReader(
+        testNetworkInterface: TestNetworkInterface,
+        handler: Handler
+    ): TapPacketReader {
+        val tunFd = testNetworkInterface.fileDescriptor.fileDescriptor
+        return TapPacketReader(handler, tunFd, testNetworkInterface.mtu).also { packetReader ->
+            // start() is posted to the handler; waitForIdle makes sure it has run.
+            handler.post { packetReader.start() }
+            handler.waitForIdle(timeoutMs = 5000)
+        }
+    }
+
+    /**
+     * Waits for the Thread module to enter any state of the given `deviceRoles`.
+     *
+     * @param controller the [ThreadNetworkController]
+     * @param deviceRoles the desired device roles. See also [ThreadNetworkController.DeviceRole]
+     * @param timeout the time to wait for the expected state before throwing
+     * @return the [ThreadNetworkController.DeviceRole] after waiting
+     * @throws TimeoutException if the device hasn't become any of expected roles until the timeout
+     * expires, or if the wait is interrupted or the underlying future fails (the original
+     * exception is attached as the cause)
+     */
+    @JvmStatic
+    @Throws(TimeoutException::class)
+    fun waitForStateAnyOf(
+        controller: ThreadNetworkController, deviceRoles: List<Int>, timeout: Duration
+    ): Int {
+        val future = SettableFuture.create<Int>()
+        val callback = ThreadNetworkController.StateCallback { newRole: Int ->
+            if (deviceRoles.contains(newRole)) {
+                future.set(newRole)
+            }
+        }
+        controller.registerStateCallback(MoreExecutors.directExecutor(), callback)
+        try {
+            return future[timeout.toMillis(), TimeUnit.MILLISECONDS]
+        } catch (e: InterruptedException) {
+            throw newRoleTimeoutException(timeout, e)
+        } catch (e: ExecutionException) {
+            throw newRoleTimeoutException(timeout, e)
+        } finally {
+            controller.unregisterStateCallback(callback)
+        }
+    }
+
+    /**
+     * Builds the [TimeoutException] thrown by [waitForStateAnyOf], chaining [cause].
+     * Uses `${cause.message}` so the message text is interpolated, not `cause.toString()`.
+     */
+    private fun newRoleTimeoutException(timeout: Duration, cause: Exception): TimeoutException =
+        TimeoutException("The device didn't become an expected role in $timeout: ${cause.message}")
+            .apply { initCause(cause) }
+
+    /**
+     * Polls for a packet from a given [TapPacketReader] that satisfies the `filter`.
+     *
+     * @param packetReader a TUN packet reader
+     * @param filter the filter to be applied on the packet
+     * @return the first packet (of any IP version) that satisfies the `filter`, or null if
+     * [TapPacketReader.poll] returns nothing within 3000ms
+     */
+    @JvmStatic
+    fun pollForPacket(packetReader: TapPacketReader, filter: Predicate<ByteArray>): ByteArray? {
+        // A single poll is all that is needed: its result (packet or null) is returned as-is,
+        // so there is nothing to loop over.
+        return packetReader.poll(3000 /* timeoutMs */, filter)
+    }
+
+    /**
+     * Returns `true` if [packet] is an IPv4 packet carrying an ICMPv4 message of the given
+     * [type]. Headers that `Struct.parse` rejects with IllegalArgumentException yield `false`.
+     */
+    @JvmStatic
+    fun isExpectedIcmpv4Packet(packet: ByteArray, type: Int): Boolean {
+        val packetBuf = makeByteBuffer(packet)
+        val ipv4Header = extractIpv4Header(packetBuf)
+        if (ipv4Header == null || ipv4Header.protocol != OsConstants.IPPROTO_ICMP.toByte()) {
+            return false
+        }
+        // The packet could have been sent by anybody, so a malformed ICMP header is tolerated.
+        return try {
+            Struct.parse(Icmpv4Header::class.java, packetBuf).type == type.toShort()
+        } catch (ignored: IllegalArgumentException) {
+            false
+        }
+    }
+
+    /**
+     * Returns `true` if [packet] is an IPv6 packet whose next header is an ICMPv6 message of the
+     * given [type]. Headers that `Struct.parse` rejects with IllegalArgumentException yield
+     * `false`.
+     */
+    @JvmStatic
+    fun isExpectedIcmpv6Packet(packet: ByteArray, type: Int): Boolean {
+        val packetBuf = makeByteBuffer(packet)
+        val ipv6Header = extractIpv6Header(packetBuf)
+        if (ipv6Header == null || ipv6Header.nextHeader != OsConstants.IPPROTO_ICMPV6.toByte()) {
+            return false
+        }
+        // The packet could have been sent by anybody, so a malformed ICMPv6 header is tolerated.
+        return try {
+            Struct.parse(Icmpv6Header::class.java, packetBuf).type == type.toShort()
+        } catch (ignored: IllegalArgumentException) {
+            false
+        }
+    }
+
+    /** Returns `true` if [packet]'s IPv4/IPv6 source address equals [src]. */
+    @JvmStatic
+    fun isFrom(packet: ByteArray, src: InetAddress): Boolean =
+        when (src) {
+            is Inet4Address -> isFromIpv4Source(packet, src)
+            is Inet6Address -> isFromIpv6Source(packet, src)
+            else -> false
+        }
+
+    /** Returns `true` if [packet]'s IPv4/IPv6 destination address equals [dest]. */
+    @JvmStatic
+    fun isTo(packet: ByteArray, dest: InetAddress): Boolean =
+        when (dest) {
+            is Inet4Address -> isToIpv4Destination(packet, dest)
+            is Inet6Address -> isToIpv6Destination(packet, dest)
+            else -> false
+        }
+
+    // False when the IPv4 header cannot be parsed.
+    private fun isFromIpv4Source(packet: ByteArray, src: Inet4Address): Boolean =
+        extractIpv4Header(makeByteBuffer(packet))?.srcIp == src
+
+    // False when the IPv6 header cannot be parsed.
+    private fun isFromIpv6Source(packet: ByteArray, src: Inet6Address): Boolean =
+        extractIpv6Header(makeByteBuffer(packet))?.srcIp == src
+
+    // False when the IPv4 header cannot be parsed.
+    private fun isToIpv4Destination(packet: ByteArray, dest: Inet4Address): Boolean =
+        extractIpv4Header(makeByteBuffer(packet))?.dstIp == dest
+
+    // False when the IPv6 header cannot be parsed.
+    private fun isToIpv6Destination(packet: ByteArray, dest: Inet6Address): Boolean =
+        extractIpv6Header(makeByteBuffer(packet))?.dstIp == dest
+
+    // Wraps (does not copy) the packet bytes so headers can be parsed sequentially.
+    private fun makeByteBuffer(packet: ByteArray): ByteBuffer = ByteBuffer.wrap(packet)
+
+    /**
+     * Parses an [Ipv4Header] starting at the current position of [buf].
+     *
+     * @return the parsed header, or `null` if parsing throws [IllegalArgumentException].
+     */
+    private fun extractIpv4Header(buf: ByteBuffer): Ipv4Header? =
+        try {
+            Struct.parse(Ipv4Header::class.java, buf)
+        } catch (ignored: IllegalArgumentException) {
+            // Packets may be sent by anybody, so malformed input is expected and tolerated.
+            null
+        }
+
+    /**
+     * Parses an [Ipv6Header] starting at the current position of [buf].
+     *
+     * @return the parsed header, or `null` if parsing throws [IllegalArgumentException].
+     */
+    private fun extractIpv6Header(buf: ByteBuffer): Ipv6Header? =
+        try {
+            Struct.parse(Ipv6Header::class.java, buf)
+        } catch (ignored: IllegalArgumentException) {
+            // Packets may be sent by anybody, so malformed input is expected and tolerated.
+            null
+        }
+
+    /**
+     * Returns the Prefix Information Options (PIO) extracted from an ICMPv6 RA message.
+     *
+     * @param raMsg a raw IPv6 packet, or `null`
+     * @return the PIOs in the order they appear. The list is empty if [raMsg] is `null`, is not
+     *     ICMPv6, or is not a Router Advertisement. Option parsing stops at the first option
+     *     whose length field is zero: such an option is invalid (RFC 4861 section 4.6) and would
+     *     otherwise never advance the parser. Exceptions thrown by `Struct.parse` on malformed
+     *     headers are not caught here.
+     */
+    @JvmStatic
+    fun getRaPios(raMsg: ByteArray?): List<PrefixInformationOption> {
+        val pioList = ArrayList<PrefixInformationOption>()
+
+        raMsg ?: return pioList
+
+        val buf = ByteBuffer.wrap(raMsg)
+        val ipv6Header = Struct.parse(Ipv6Header::class.java, buf)
+        if (ipv6Header.nextHeader != OsConstants.IPPROTO_ICMPV6.toByte()) {
+            return pioList
+        }
+
+        val icmpv6Header = Struct.parse(Icmpv6Header::class.java, buf)
+        if (icmpv6Header.type != NetworkStackConstants.ICMPV6_ROUTER_ADVERTISEMENT.toShort()) {
+            return pioList
+        }
+
+        // Skip the fixed RA header; the ND options follow it.
+        Struct.parse(RaHeader::class.java, buf)
+        val pioSize = Struct.getSize(PrefixInformationOption::class.java)
+        while (buf.position() < raMsg.size) {
+            val currentPos = buf.position()
+            val type = toUnsignedInt(buf.get())
+            // The length is in units of 8 octets.
+            val length = toUnsignedInt(buf.get())
+            if (length == 0) {
+                // A zero-length option is invalid and would leave the position unchanged,
+                // looping forever; stop parsing instead.
+                break
+            }
+            if (type == NetworkStackConstants.ICMPV6_ND_OPTION_PIO) {
+                // Parse the PIO from a view that begins at the option's type byte.
+                val pioBuf = ByteBuffer.wrap(buf.array(), currentPos, pioSize)
+                pioList.add(Struct.parse(PrefixInformationOption::class.java, pioBuf))
+                // Move ByteBuffer position to the next option.
+                buf.position(currentPos + pioSize)
+            } else {
+                buf.position(currentPos + (length * 8))
+            }
+        }
+        return pioList
+    }
+
+    /**
+     * Sends [message] as the payload of a single UDP datagram to [dstAddress]:[dstPort].
+     *
+     * The message is encoded with [String.toByteArray] (UTF-8). A fresh [DatagramSocket] is
+     * used and closed after sending.
+     *
+     * @param dstAddress the IP address of the destination
+     * @param dstPort the port of the destination
+     * @param message the message in UDP payload
+     * @throws IOException if failed to send the message
+     */
+    @JvmStatic
+    @Throws(IOException::class)
+    fun sendUdpMessage(dstAddress: InetAddress, dstPort: Int, message: String) {
+        val destination: SocketAddress = InetSocketAddress(dstAddress, dstPort)
+        val payload = message.toByteArray()
+        DatagramSocket().use { udpSocket ->
+            udpSocket.connect(destination)
+            udpSocket.send(DatagramPacket(payload, payload.size))
+        }
+    }
+
+    /**
+     * Returns `true` if [interfaceName] has joined the IPv6 multicast group [address], according
+     * to the output of `ip -6 maddr show dev <interfaceName>`.
+     *
+     * Each output line is split into whitespace-separated tokens, and one token must equal
+     * `address.hostAddress` exactly. This relies on `hostAddress` using the same textual form as
+     * `ip`, as the previous substring check did.
+     */
+    @JvmStatic
+    fun isInMulticastGroup(interfaceName: String, address: Inet6Address): Boolean {
+        val output: String = runShellCommandOrThrow("ip -6 maddr show dev $interfaceName")
+        val addressStr = address.hostAddress
+        val whitespace = Regex("\\s+")
+        // Compare whole tokens: a substring check would report ff02::1 as joined whenever a
+        // longer group such as the solicited-node address ff02::1:ff00:1 is listed.
+        return output.lines().any { line -> addressStr in line.split(whitespace) }
+    }
+
+    /**
+     * Returns the IPv6 addresses on [interfaceName], parsed from the `inet6` lines printed by
+     * `ip -6 addr show dev <interfaceName>`.
+     */
+    @JvmStatic
+    fun getIpv6LinkAddresses(interfaceName: String): List<LinkAddress> {
+        val output: String = runShellCommandOrThrow("ip -6 addr show dev $interfaceName")
+        // parseAddressLine expects "inet6" to be the first token, so only accept lines that
+        // start with it rather than any line that merely contains the text "inet6".
+        return output.lines()
+            .filter { line -> line.trimStart().startsWith("inet6 ") }
+            .map { line -> parseAddressLine(line) }
+    }
+
+    /**
+     * Starts DNS-SD discovery for [serviceType] and returns the first service reported through
+     * `onServiceFound`.
+     *
+     * Discovery is stopped before returning, including when the wait times out.
+     *
+     * @throws java.util.concurrent.TimeoutException if no service is found within
+     *     `SERVICE_DISCOVERY_TIMEOUT`
+     */
+    @JvmStatic
+    @Throws(Exception::class)
+    fun discoverService(nsdManager: NsdManager, serviceType: String): NsdServiceInfo {
+        val serviceInfoFuture = CompletableFuture<NsdServiceInfo>()
+        val listener: NsdManager.DiscoveryListener = object : DefaultDiscoveryListener() {
+            override fun onServiceFound(serviceInfo: NsdServiceInfo) {
+                serviceInfoFuture.complete(serviceInfo)
+            }
+        }
+        nsdManager.discoverServices(serviceType, NsdManager.PROTOCOL_DNS_SD, listener)
+        try {
+            // Return the timed wait's result directly instead of discarding it and blocking on
+            // an untimed get() afterwards.
+            return serviceInfoFuture.get(
+                SERVICE_DISCOVERY_TIMEOUT.toMillis(),
+                TimeUnit.MILLISECONDS
+            )
+        } finally {
+            nsdManager.stopServiceDiscovery(listener)
+        }
+    }
+
+    /**
+     * Starts DNS-SD discovery for [serviceType] and completes [serviceInfoFuture] with the first
+     * service reported through `onServiceLost`.
+     *
+     * Discovery is not stopped here; pass the returned listener to
+     * [NsdManager.stopServiceDiscovery] when done.
+     *
+     * @return the listener registered with [nsdManager]
+     */
+    @JvmStatic
+    fun discoverForServiceLost(
+        nsdManager: NsdManager,
+        serviceType: String?,
+        serviceInfoFuture: CompletableFuture<NsdServiceInfo?>
+    ): NsdManager.DiscoveryListener {
+        val lostListener = object : DefaultDiscoveryListener() {
+            override fun onServiceLost(serviceInfo: NsdServiceInfo) {
+                serviceInfoFuture.complete(serviceInfo)
+            }
+        }
+        nsdManager.discoverServices(serviceType, NsdManager.PROTOCOL_DNS_SD, lostListener)
+        return lostListener
+    }
+
+    /** Resolves [serviceInfo] and returns the first update reported for it. */
+    @JvmStatic
+    @Throws(Exception::class)
+    fun resolveService(nsdManager: NsdManager, serviceInfo: NsdServiceInfo): NsdServiceInfo =
+        resolveServiceUntil(nsdManager, serviceInfo) { _ -> true }
+
+    /**
+     * Registers a service-info callback for [serviceInfo] and returns the first update that
+     * satisfies [predicate].
+     *
+     * The callback is unregistered before returning, including when the wait times out.
+     *
+     * @throws java.util.concurrent.TimeoutException if no matching update arrives within
+     *     `SERVICE_DISCOVERY_TIMEOUT`
+     */
+    @JvmStatic
+    @Throws(Exception::class)
+    fun resolveServiceUntil(
+        nsdManager: NsdManager,
+        serviceInfo: NsdServiceInfo,
+        predicate: Predicate<NsdServiceInfo>
+    ): NsdServiceInfo {
+        val matchFuture = CompletableFuture<NsdServiceInfo>()
+        val callback: NsdManager.ServiceInfoCallback = object : DefaultServiceInfoCallback() {
+            override fun onServiceUpdated(serviceInfo: NsdServiceInfo) {
+                if (!predicate.test(serviceInfo)) return
+                matchFuture.complete(serviceInfo)
+            }
+        }
+        nsdManager.registerServiceInfoCallback(serviceInfo, directExecutor(), callback)
+        return try {
+            matchFuture.get(SERVICE_DISCOVERY_TIMEOUT.toMillis(), TimeUnit.MILLISECONDS)
+        } finally {
+            nsdManager.unregisterServiceInfoCallback(callback)
+        }
+    }
+
+    /**
+     * Returns the prefixes section of a network-data dump: the text from "Prefixes:" up to, but
+     * not including, the first "Routes:" that follows it.
+     *
+     * @throws IllegalArgumentException if [netData] has no "Prefixes:" marker, or no "Routes:"
+     *     marker after it
+     */
+    @JvmStatic
+    fun getPrefixesFromNetData(netData: String): String {
+        val startIdx = netData.indexOf("Prefixes:")
+        require(startIdx >= 0) { "No \"Prefixes:\" section in network data: $netData" }
+        // Look for the end marker only after the start, so an earlier "Routes:" cannot produce
+        // an inverted (and throwing) substring range.
+        val endIdx = netData.indexOf("Routes:", startIdx)
+        require(endIdx >= 0) { "No \"Routes:\" after \"Prefixes:\" in network data: $netData" }
+        return netData.substring(startIdx, endIdx)
+    }
+
+    /**
+     * Waits up to [timeout] for a Thread network to become available and returns it.
+     *
+     * The network callback is unregistered before returning, whether or not a network was
+     * found, so repeated calls do not accumulate registered callbacks.
+     *
+     * @throws java.util.concurrent.TimeoutException if no Thread network appears in time
+     */
+    @JvmStatic
+    @Throws(Exception::class)
+    fun getThreadNetwork(timeout: Duration): Network {
+        val networkFuture = CompletableFuture<Network>()
+        val cm =
+            ApplicationProvider.getApplicationContext<Context>()
+                .getSystemService(ConnectivityManager::class.java)
+        val networkRequestBuilder =
+            NetworkRequest.Builder().addTransportType(NetworkCapabilities.TRANSPORT_THREAD)
+        // Before V, we need to explicitly set `NET_CAPABILITY_LOCAL_NETWORK` capability to request
+        // a Thread network.
+        if (Build.VERSION.SDK_INT <= Build.VERSION_CODES.UPSIDE_DOWN_CAKE) {
+            networkRequestBuilder.addCapability(NetworkCapabilities.NET_CAPABILITY_LOCAL_NETWORK)
+        }
+        val networkRequest = networkRequestBuilder.build()
+        val networkCallback: ConnectivityManager.NetworkCallback =
+            object : ConnectivityManager.NetworkCallback() {
+                override fun onAvailable(network: Network) {
+                    networkFuture.complete(network)
+                }
+            }
+        cm.registerNetworkCallback(networkRequest, networkCallback)
+        try {
+            // Wait in milliseconds so sub-second timeouts are not truncated to zero.
+            return networkFuture.get(timeout.toMillis(), TimeUnit.MILLISECONDS)
+        } finally {
+            // registerNetworkCallback only listens (it does not request the network), so
+            // unregistering does not affect the returned network; it just avoids a leak.
+            cm.unregisterNetworkCallback(networkCallback)
+        }
+    }
+
+    /**
+     * Let the FTD join the specified Thread network and wait for border routing to be available.
+     *
+     * The FTD is factory-reset first, then joins [dataset] and must reach the "router" or
+     * "child" state within `JOIN_TIMEOUT`. After that, its OMR address must become non-null
+     * within 60 seconds.
+     *
+     * @return the OMR address
+     */
+    @JvmStatic
+    @Throws(Exception::class)
+    fun joinNetworkAndWaitForOmr(
+        ftd: FullThreadDevice, dataset: ActiveOperationalDataset
+    ): Inet6Address {
+        ftd.factoryReset()
+        ftd.joinNetwork(dataset)
+        ftd.waitForStateAnyOf(listOf("router", "child"), JOIN_TIMEOUT)
+        waitFor({ ftd.omrAddress != null }, Duration.ofSeconds(60))
+        // Read the address once so the value asserted non-null is the value returned.
+        val omrAddress = ftd.omrAddress
+        Assert.assertNotNull(omrAddress)
+        return omrAddress
+    }
+
+    /**
+     * [NsdManager.DiscoveryListener] whose callbacks all do nothing, so subclasses only override
+     * the callbacks they care about.
+     */
+    private open class DefaultDiscoveryListener : NsdManager.DiscoveryListener {
+        override fun onDiscoveryStarted(serviceType: String) {}
+        override fun onDiscoveryStopped(serviceType: String) {}
+        override fun onStartDiscoveryFailed(serviceType: String, errorCode: Int) {}
+        override fun onStopDiscoveryFailed(serviceType: String, errorCode: Int) {}
+        override fun onServiceFound(serviceInfo: NsdServiceInfo) {}
+        override fun onServiceLost(serviceInfo: NsdServiceInfo) {}
+    }
+
+    /**
+     * [NsdManager.ServiceInfoCallback] whose callbacks all do nothing, so subclasses only
+     * override the callbacks they care about.
+     */
+    private open class DefaultServiceInfoCallback : NsdManager.ServiceInfoCallback {
+        override fun onServiceUpdated(serviceInfo: NsdServiceInfo) {}
+        override fun onServiceLost() {}
+        override fun onServiceInfoCallbackRegistrationFailed(errorCode: Int) {}
+        override fun onServiceInfoCallbackUnregistered() {}
+    }
+
+    /**
+     * Parses a line of output from "ip -6 addr show" into a [LinkAddress].
+     *
+     * The second whitespace-separated token must have the form `<address>/<prefixLength>`.
+     * - If the line contains "deprecated", the deprecation time is the current
+     *   [SystemClock.elapsedRealtime]. Otherwise it is [LinkAddress.LIFETIME_PERMANENT].
+     * - Flags and scope are always 0.
+     * - The expiration time is always [LinkAddress.LIFETIME_PERMANENT].
+     *
+     * Example line: "inet6 2001:db8:1:1::1/64 scope global deprecated"
+     */
+    private fun parseAddressLine(line: String): LinkAddress {
+        val tokens = line.split("\\s+".toRegex()).filterNot { it.isEmpty() }
+        val addressAndPrefix = tokens[1].split("/".toRegex(), limit = 2)
+        val prefixLength = addressAndPrefix[1].toInt()
+        val address = parseNumericAddress(addressAndPrefix[0])
+        val deprecationTime = when {
+            line.contains("deprecated") -> SystemClock.elapsedRealtime()
+            else -> LinkAddress.LIFETIME_PERMANENT
+        }
+        return LinkAddress(
+            address,
+            prefixLength,
+            0 /* flags */,
+            0 /* scope */,
+            deprecationTime,
+            LinkAddress.LIFETIME_PERMANENT /* expirationTime */
+        )
+    }
+
+    /**
+     * Creates an [InfraNetworkDevice] on [tapPacketReader] with [macAddress], runs SLAAC on it
+     * with a 60-second duration, and returns it.
+     *
+     * @throws IllegalArgumentException if the device has no IPv6 address after SLAAC
+     */
+    @JvmStatic
+    @JvmOverloads
+    fun startInfraDeviceAndWaitForOnLinkAddr(
+        tapPacketReader: TapPacketReader,
+        macAddress: MacAddress = MacAddress.fromString("1:2:3:4:5:6")
+    ): InfraNetworkDevice {
+        return InfraNetworkDevice(macAddress, tapPacketReader).also { device ->
+            device.runSlaac(Duration.ofSeconds(60))
+            requireNotNull(device.ipv6Addr)
+        }
+    }
+
+    /**
+     * Creates a test network to act as the infrastructure network and sets it as the upstream
+     * of [controller], waiting for that to take effect.
+     *
+     * The network has DNS server 8.8.8.8 and an IPv4 default route (MTU 1500).
+     *
+     * @return the tracker of the created test network
+     */
+    @JvmStatic
+    @Throws(java.lang.Exception::class)
+    fun setUpInfraNetwork(
+        context: Context, controller: ThreadNetworkControllerWrapper
+    ): TestNetworkTracker {
+        val lp = LinkProperties().apply {
+            // TODO: use a fake DNS server
+            setDnsServers(listOf(parseNumericAddress("8.8.8.8")))
+            // NAT64 feature requires the infra network to have an IPv4 default route.
+            addRoute(
+                RouteInfo(
+                    IpPrefix("0.0.0.0/0") /* destination */,
+                    null /* gateway */,
+                    null /* iface */,
+                    RouteInfo.RTN_UNICAST,
+                    1500 /* mtu */
+                )
+            )
+        }
+        val tracker: TestNetworkTracker = runAsShell(
+            MANAGE_TEST_NETWORKS,
+            supplier = { initTestNetwork(context, lp, setupTimeoutMs = 5000) }
+        )
+        controller.setTestNetworkAsUpstreamAndWait(tracker.testIface.interfaceName)
+        return tracker
+    }
+
+    /** Calls `teardown()` on [testNetworkTracker] via `runAsShell(MANAGE_TEST_NETWORKS)`. */
+    @JvmStatic
+    fun tearDownInfraNetwork(testNetworkTracker: TestNetworkTracker) {
+        runAsShell(MANAGE_TEST_NETWORKS) {
+            testNetworkTracker.teardown()
+        }
+    }
+}