Merge "Revert "[ST03] Add test dns server for integration tests""
diff --git a/staticlibs/device/com/android/net/module/util/NetworkMonitorUtils.java b/staticlibs/device/com/android/net/module/util/NetworkMonitorUtils.java
index f6cd044..5a4412f 100644
--- a/staticlibs/device/com/android/net/module/util/NetworkMonitorUtils.java
+++ b/staticlibs/device/com/android/net/module/util/NetworkMonitorUtils.java
@@ -16,6 +16,7 @@
package com.android.net.module.util;
+import static android.net.NetworkCapabilities.NET_CAPABILITY_DUN;
import static android.net.NetworkCapabilities.NET_CAPABILITY_INTERNET;
import static android.net.NetworkCapabilities.NET_CAPABILITY_NOT_RESTRICTED;
import static android.net.NetworkCapabilities.NET_CAPABILITY_NOT_VPN;
@@ -102,9 +103,12 @@
* networks.
* @param nc Network capabilities of the network to test.
*/
- public static boolean isValidationRequired(boolean isVpnValidationRequired,
+ public static boolean isValidationRequired(boolean isDunValidationRequired,
+ boolean isVpnValidationRequired,
@NonNull final NetworkCapabilities nc) {
- // TODO: Consider requiring validation for DUN networks.
+ if (isDunValidationRequired && nc.hasCapability(NET_CAPABILITY_DUN)) {
+ return true;
+ }
if (!nc.hasCapability(NET_CAPABILITY_NOT_VPN)) {
return isVpnValidationRequired;
}
diff --git a/staticlibs/framework/com/android/net/module/util/BitUtils.java b/staticlibs/framework/com/android/net/module/util/BitUtils.java
new file mode 100644
index 0000000..2b32e86
--- /dev/null
+++ b/staticlibs/framework/com/android/net/module/util/BitUtils.java
@@ -0,0 +1,110 @@
+/*
+ * Copyright (C) 2022 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.net.module.util;
+
+import android.annotation.NonNull;
+
+/**
+ * @hide
+ */
+public class BitUtils {
+    /**
+     * Returns the positions of the bits set in {@code val}, in ascending order.
+     */
+    public static int[] unpackBits(long val) {
+        int size = Long.bitCount(val);
+        int[] result = new int[size];
+        int index = 0;
+        int bitPos = 0;
+        while (val != 0) {
+            if ((val & 1) == 1) result[index++] = bitPos;
+            val = val >>> 1;
+            bitPos++;
+        }
+        return result;
+    }
+
+    /**
+     * Varargs form of {@link #packBits(int[])}: packs the passed bit positions into a long.
+     *
+     * Each passed int is the position of a bit that should be set in the returned long.
+     * Example: passing (1,3) returns 0b00001010 and passing (5,6,0) returns 0b01100001.
+     *
+     * @param bits positions of the bits to set
+     * @return a long with the specified bits set.
+     */
+    public static long packBitList(int... bits) {
+        return packBits(bits);
+    }
+
+    /**
+     * Packs an array of bit positions into a long value.
+     *
+     * Each passed int is the position of a bit that should be set in the returned long.
+     * Example: passing [1,3] returns 0b00001010 and passing [5,6,0] returns 0b01100001.
+     *
+     * @param bits positions of the bits to set; only the low 6 bits of each are used
+     * @return a long with the specified bits set.
+     */
+    public static long packBits(int[] bits) {
+        long packed = 0;
+        for (int b : bits) {
+            packed |= (1L << b);
+        }
+        return packed;
+    }
+
+    /**
+     * A function that returns the name associated with an int, such as a bit position.
+     *
+     * This is useful for bitfields like network capabilities or network score policies.
+     */
+    @FunctionalInterface
+    public interface NameOf {
+        /** Returns the name associated with the passed value. */
+        String nameOf(int value);
+    }
+
+    /**
+     * Appends the names of all bits set in {@code bitMask} to {@code sb}.
+     *
+     * Bits are visited from position 0 upwards. The position of each set bit is passed to
+     * {@code nameFetcher}, and the returned names are appended to {@code sb} with
+     * {@code separator} between consecutive names. Nothing is appended for a zero mask.
+     *
+     * For example, if the bitmask is 0b0110, the name fetcher returns "BIT_1" to "BIT_4" for
+     * numbers from 1 to 4, and the separator is "&", this method appends "BIT_1&BIT_2" to the
+     * StringBuilder.
+     */
+    public static void appendStringRepresentationOfBitMaskToStringBuilder(@NonNull StringBuilder sb,
+            long bitMask, @NonNull NameOf nameFetcher, @NonNull String separator) {
+        int bitPos = 0;
+        boolean firstElementAdded = false;
+        while (bitMask != 0) {
+            if ((bitMask & 1) != 0) {
+                if (firstElementAdded) {
+                    sb.append(separator);
+                } else {
+                    firstElementAdded = true;
+                }
+                sb.append(nameFetcher.nameOf(bitPos));
+            }
+            bitMask >>>= 1;
+            ++bitPos;
+        }
+    }
+}
diff --git a/staticlibs/framework/com/android/net/module/util/ByteUtils.java b/staticlibs/framework/com/android/net/module/util/ByteUtils.java
new file mode 100644
index 0000000..290ed46
--- /dev/null
+++ b/staticlibs/framework/com/android/net/module/util/ByteUtils.java
@@ -0,0 +1,67 @@
+/*
+ * Copyright (C) 2022 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.net.module.util;
+
+import android.annotation.NonNull;
+
+/**
+ * Byte utility functions.
+ * @hide
+ */
+public class ByteUtils {
+    /**
+     * Searches {@code array} for {@code target} and reports where it first occurs.
+     *
+     * @param array the {@code byte} values to search, possibly empty
+     * @param target the {@code byte} value to look for
+     * @return the smallest index {@code i} such that {@code array[i] == target}, or {@code -1}
+     *     if {@code target} does not occur in {@code array}.
+     */
+    public static int indexOf(@NonNull byte[] array, byte target) {
+        return indexOf(array, target, 0, array.length);
+    }
+
+    /** Returns the first index in [start, end) holding {@code target}, or -1 if there is none. */
+    private static int indexOf(byte[] array, byte target, int start, int end) {
+        int i = start;
+        while (i < end && array[i] != target) {
+            i++;
+        }
+        return i < end ? i : -1;
+    }
+
+    /**
+     * Concatenates the given arrays, in order, into a newly allocated array. For example,
+     * {@code concat(new byte[] {a, b}, new byte[] {}, new byte[] {c})} returns {@code {a, b, c}}.
+     *
+     * @param arrays zero or more {@code byte} arrays
+     * @return a new array containing all the values from the source arrays, in order
+     */
+    public static byte[] concat(@NonNull byte[]... arrays) {
+        int totalLength = 0;
+        for (int i = 0; i < arrays.length; i++) {
+            totalLength += arrays[i].length;
+        }
+        final byte[] joined = new byte[totalLength];
+        int offset = 0;
+        for (final byte[] part : arrays) {
+            System.arraycopy(part, 0, joined, offset, part.length);
+            offset += part.length;
+        }
+        return joined;
+    }
+}
diff --git a/staticlibs/framework/com/android/net/module/util/CollectionUtils.java b/staticlibs/framework/com/android/net/module/util/CollectionUtils.java
index 7cac90d..f08880c 100644
--- a/staticlibs/framework/com/android/net/module/util/CollectionUtils.java
+++ b/staticlibs/framework/com/android/net/module/util/CollectionUtils.java
@@ -18,12 +18,15 @@
import android.annotation.NonNull;
import android.annotation.Nullable;
+import android.util.ArrayMap;
+import android.util.Pair;
import android.util.SparseArray;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Objects;
+import java.util.function.Function;
import java.util.function.Predicate;
/**
@@ -291,4 +294,100 @@
@NonNull final Predicate<? super T> condition) {
return -1 != indexOf(haystack, condition);
}
+
+    /**
+     * Standard map function, but returns a new modifiable ArrayList.
+     *
+     * This returns a new list that contains, for each element of the source collection, its
+     * image through the passed transform.
+     * Elements in the source can be null if the transform accepts null inputs.
+     * Elements in the output can be null if the transform ever returns null.
+     * This function never returns null. If the source collection is empty, it returns a new
+     * empty list.
+     * Contract : this method calls the transform function exactly once for each element in the
+     * collection, in iteration order.
+     *
+     * @param source the source collection
+     * @param transform the function to transform the elements
+     * @param <T> type of source elements
+     * @param <R> type of destination elements
+     * @return a new modifiable ArrayList of transformed elements, in source iteration order
+     */
+    @NonNull
+    public static <T, R> ArrayList<R> map(@NonNull final Collection<T> source,
+            @NonNull final Function<? super T, ? extends R> transform) {
+        final ArrayList<R> dest = new ArrayList<>(source.size());
+        for (final T e : source) {
+            dest.add(transform.apply(e));
+        }
+        return dest;
+    }
+
+    /**
+     * Standard zip function, but returns a new modifiable ArrayList.
+     *
+     * The element at index {@code i} of the result is a {@link Pair} of {@code first.get(i)} and
+     * {@code second.get(i)}: picture the two lists laid side by side and stitched together with
+     * a zipper. Both lists must be the same size; either may contain null elements.
+     *
+     * Contract : this method will read each element of each list exactly once, in some unspecified
+     * order. If it throws, it will not read any element.
+     *
+     * This function never returns null; the result is always a newly allocated list.
+     *
+     * @param first the list supplying the first element of each pair
+     * @param second the list supplying the second element of each pair
+     * @param <T> the type of first elements
+     * @param <R> the type of second elements
+     * @return the zipped list, which is empty when both inputs are empty
+     * @throws IllegalArgumentException if the two lists do not have the same size
+     */
+    @NonNull
+    public static <T, R> ArrayList<Pair<T, R>> zip(@NonNull final List<T> first,
+            @NonNull final List<R> second) {
+        final int count = first.size();
+        if (second.size() != count) {
+            throw new IllegalArgumentException("zip : collections must be the same size");
+        }
+        final ArrayList<Pair<T, R>> zipped = new ArrayList<>(count);
+        for (int index = 0; index < count; index++) {
+            zipped.add(new Pair<>(first.get(index), second.get(index)));
+        }
+        return zipped;
+    }
+
+    /**
+     * Returns a new ArrayMap that maps {@code keys.get(i)} to {@code values.get(i)} for every i.
+     *
+     * The lists must be the same size. Both may contain null, but keys must be distinct.
+     *
+     * Contract : this method will read each element of each list exactly once, but does not
+     * specify the order, except if it throws in which case the number of reads is undefined.
+     *
+     * @param keys The list of keys
+     * @param values The list of values
+     * @param <T> The type of keys
+     * @param <R> The type of values
+     * @return The associated map
+     * @throws IllegalArgumentException if the sizes differ or a key (even null) appears twice
+     */
+    @NonNull
+    public static <T, R> ArrayMap<T, R> assoc(
+            @NonNull final List<T> keys, @NonNull final List<R> values) {
+        final int count = keys.size();
+        if (values.size() != count) {
+            throw new IllegalArgumentException("assoc : collections must be the same size");
+        }
+        final ArrayMap<T, R> result = new ArrayMap<>(count);
+        for (int index = 0; index < count; index++) {
+            final T key = keys.get(index);
+            if (!result.containsKey(key)) {
+                result.put(key, values.get(index));
+            } else {
+                throw new IllegalArgumentException(
+                        "assoc : keys may not contain the same value twice");
+            }
+        }
+        return result;
+    }
}
diff --git a/staticlibs/framework/com/android/net/module/util/NetworkCapabilitiesUtils.java b/staticlibs/framework/com/android/net/module/util/NetworkCapabilitiesUtils.java
index 26c24f8..54ce01e 100644
--- a/staticlibs/framework/com/android/net/module/util/NetworkCapabilitiesUtils.java
+++ b/staticlibs/framework/com/android/net/module/util/NetworkCapabilitiesUtils.java
@@ -44,6 +44,9 @@
import static android.net.NetworkCapabilities.TRANSPORT_WIFI;
import static android.net.NetworkCapabilities.TRANSPORT_WIFI_AWARE;
+import static com.android.net.module.util.BitUtils.packBitList;
+import static com.android.net.module.util.BitUtils.unpackBits;
+
import android.annotation.NonNull;
import android.net.NetworkCapabilities;
@@ -181,49 +184,4 @@
return false;
}
- /**
- * Unpacks long value into an array of bits.
- */
- public static int[] unpackBits(long val) {
- int size = Long.bitCount(val);
- int[] result = new int[size];
- int index = 0;
- int bitPos = 0;
- while (val != 0) {
- if ((val & 1) == 1) result[index++] = bitPos;
- val = val >>> 1;
- bitPos++;
- }
- return result;
- }
-
- /**
- * Packs a list of ints in the same way as packBits()
- *
- * Each passed int is the rank of a bit that should be set in the returned long.
- * Example : passing (1,3) will return in 0b00001010 and passing (5,6,0) will return 0b01100001
- *
- * @param bits bits to pack
- * @return a long with the specified bits set.
- */
- public static long packBitList(int... bits) {
- return packBits(bits);
- }
-
- /**
- * Packs array of bits into a long value.
- *
- * Each passed int is the rank of a bit that should be set in the returned long.
- * Example : passing [1,3] will return in 0b00001010 and passing [5,6,0] will return 0b01100001
- *
- * @param bits bits to pack
- * @return a long with the specified bits set.
- */
- public static long packBits(int[] bits) {
- long packed = 0;
- for (int b : bits) {
- packed |= (1L << b);
- }
- return packed;
- }
}
diff --git a/staticlibs/native/bpf_headers/include/bpf/BpfUtils.h b/staticlibs/native/bpf_headers/include/bpf/BpfUtils.h
index 4429164..157f210 100644
--- a/staticlibs/native/bpf_headers/include/bpf/BpfUtils.h
+++ b/staticlibs/native/bpf_headers/include/bpf/BpfUtils.h
@@ -115,23 +115,16 @@
return kernelVersion() >= KVER(major, minor, sub);
}
-#define SKIP_IF_BPF_SUPPORTED \
- do { \
- if (android::bpf::isAtLeastKernelVersion(4, 9, 0)) \
- GTEST_SKIP() << "Skip: bpf is supported."; \
- } while (0)
-
#define SKIP_IF_BPF_NOT_SUPPORTED \
do { \
if (!android::bpf::isAtLeastKernelVersion(4, 9, 0)) \
GTEST_SKIP() << "Skip: bpf is not supported."; \
} while (0)
-#define SKIP_IF_EXTENDED_BPF_NOT_SUPPORTED \
- do { \
- if (!android::bpf::isAtLeastKernelVersion(4, 14, 0)) \
- GTEST_SKIP() << "Skip: extended bpf feature not supported."; \
- } while (0)
+// Only used by tm-mainline-prod's system/netd/tests/bpf_base_test.cpp, but platform and
+// platform tests aren't expected to build/work in tm-mainline-prod, so this is a no-op.
+// It keeps the same do { } while (0) statement shape as the other SKIP_IF_* macros.
+#define SKIP_IF_EXTENDED_BPF_NOT_SUPPORTED do { } while (0)
#define SKIP_IF_XDP_NOT_SUPPORTED \
do { \
diff --git a/staticlibs/native/bpf_headers/include/bpf/bpf_helpers.h b/staticlibs/native/bpf_headers/include/bpf/bpf_helpers.h
index 35d9f31..c652c76 100644
--- a/staticlibs/native/bpf_headers/include/bpf/bpf_helpers.h
+++ b/staticlibs/native/bpf_headers/include/bpf/bpf_helpers.h
@@ -147,6 +147,23 @@
__attribute__ ((section(".maps." #name), used)) \
____btf_map_##name = { }
+/* There exist buggy kernels with pre-T OS that, due to the kernel patch
+ * "[ALPS05162612] bpf: fix ubsan error", do not support userspace writes into a non-zero
+ * index of a BPF_MAP_TYPE_ARRAY map.
+ *
+ * BPF_MAP_ASSERT_OK() makes defining such a map with more than 1 entry and a writable mode
+ * ((mode & 0222) != 0) a compile-time error, unless a branch below compiles the check out. */
+
+#ifdef THIS_BPF_PROGRAM_IS_FOR_TEST_PURPOSES_ONLY  // test-only program: no check
+#define BPF_MAP_ASSERT_OK(type, entries, mode)
+#elif BPFLOADER_MIN_VER >= BPFLOADER_T_BETA3_VERSION  // NOTE(review): undefined names read as 0
+#define BPF_MAP_ASSERT_OK(type, entries, mode)
+#else  // enforce the restriction at compile time
+#define BPF_MAP_ASSERT_OK(type, entries, mode) \
+ _Static_assert(((type) != BPF_MAP_TYPE_ARRAY) || ((entries) <= 1) || !((mode) & 0222), \
+ "Writable arrays with more than 1 element not supported on pre-T devices.")
+#endif  // THIS_BPF_PROGRAM_IS_FOR_TEST_PURPOSES_ONLY / BPFLOADER_MIN_VER
+
/* type safe macro to declare a map and related accessor functions */
#define DEFINE_BPF_MAP_EXT(the_map, TYPE, KeyType, ValueType, num_entries, usr, grp, md, \
selinux, pindir, share) \
@@ -167,6 +184,7 @@
.pin_subdir = pindir, \
.shared = share, \
}; \
+ BPF_MAP_ASSERT_OK(BPF_MAP_TYPE_##TYPE, (num_entries), (md)); \
BPF_ANNOTATE_KV_PAIR(the_map, KeyType, ValueType); \
\
static inline __always_inline __unused ValueType* bpf_##the_map##_lookup_elem( \
@@ -205,6 +223,10 @@
DEFINE_BPF_MAP_UGM(the_map, TYPE, KeyType, ValueType, num_entries, \
DEFAULT_BPF_MAP_UID, AID_ROOT, 0600)
+#define DEFINE_BPF_MAP_RO(the_map, TYPE, KeyType, ValueType, num_entries, gid) \
+    DEFINE_BPF_MAP_UGM(the_map, TYPE, KeyType, ValueType, num_entries, DEFAULT_BPF_MAP_UID, \
+                       gid, 0440) /* 0440: readable by owner and group, no write bits */
+
#define DEFINE_BPF_MAP_GWO(the_map, TYPE, KeyType, ValueType, num_entries, gid) \
DEFINE_BPF_MAP_UGM(the_map, TYPE, KeyType, ValueType, num_entries, \
DEFAULT_BPF_MAP_UID, gid, 0620)
diff --git a/staticlibs/tests/unit/src/com/android/net/module/util/BitUtilsTests.kt b/staticlibs/tests/unit/src/com/android/net/module/util/BitUtilsTests.kt
new file mode 100644
index 0000000..0236716
--- /dev/null
+++ b/staticlibs/tests/unit/src/com/android/net/module/util/BitUtilsTests.kt
@@ -0,0 +1,61 @@
+/*
+ * Copyright (C) 2022 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.net.module.util
+
+import com.android.net.module.util.BitUtils.appendStringRepresentationOfBitMaskToStringBuilder
+import com.android.net.module.util.BitUtils.packBits
+import com.android.net.module.util.BitUtils.unpackBits
+import org.junit.Test
+import kotlin.test.assertEquals
+import kotlin.test.assertTrue
+
+class BitUtilsTests {
+    @Test
+    fun testBitPackingTestCase() {
+        runBitPackingTestCase(0, intArrayOf())
+        runBitPackingTestCase(1, intArrayOf(0))
+        runBitPackingTestCase(3, intArrayOf(0, 1))
+        runBitPackingTestCase(4, intArrayOf(2))
+        runBitPackingTestCase(63, intArrayOf(0, 1, 2, 3, 4, 5))
+        runBitPackingTestCase(Long.MAX_VALUE.inv(), intArrayOf(63))
+        runBitPackingTestCase(Long.MAX_VALUE.inv() + 1, intArrayOf(0, 63))
+        runBitPackingTestCase(Long.MAX_VALUE.inv() + 2, intArrayOf(1, 63))
+    }
+
+    fun runBitPackingTestCase(packedBits: Long, bits: IntArray) {
+        assertEquals(packedBits, packBits(bits))
+        assertEquals(packedBits, BitUtils.packBitList(*bits))
+        assertTrue(bits contentEquals unpackBits(packedBits))
+    }
+
+    @Test
+    fun testAppendStringRepresentationOfBitMaskToStringBuilder() {
+        runTestAppendStringRepresentationOfBitMaskToStringBuilder("", 0)
+        runTestAppendStringRepresentationOfBitMaskToStringBuilder("BIT0", 0b1)
+        runTestAppendStringRepresentationOfBitMaskToStringBuilder("BIT1&BIT2&BIT4", 0b10110)
+        runTestAppendStringRepresentationOfBitMaskToStringBuilder(
+                "BIT0&BIT60&BIT61&BIT62&BIT63",
+                (0b11110000_00000000_00000000_00000000 shl 32) +
+                        0b00000000_00000000_00000000_00000001)
+    }
+
+    fun runTestAppendStringRepresentationOfBitMaskToStringBuilder(expected: String, bitMask: Long) {
+        assertEquals(expected, buildString {
+            appendStringRepresentationOfBitMaskToStringBuilder(this, bitMask, { i -> "BIT$i" }, "&")
+        })
+    }
+}
diff --git a/staticlibs/tests/unit/src/com/android/net/module/util/ByteUtilsTests.kt b/staticlibs/tests/unit/src/com/android/net/module/util/ByteUtilsTests.kt
new file mode 100644
index 0000000..e58adad
--- /dev/null
+++ b/staticlibs/tests/unit/src/com/android/net/module/util/ByteUtilsTests.kt
@@ -0,0 +1,54 @@
+/*
+ * Copyright (C) 2022 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.net.module.util
+
+import com.android.net.module.util.ByteUtils.indexOf
+import com.android.net.module.util.ByteUtils.concat
+import org.junit.Test
+import kotlin.test.assertContentEquals
+import kotlin.test.assertEquals
+import kotlin.test.assertNotSame
+
+class ByteUtilsTests {
+    private val EMPTY = byteArrayOf()
+    private val ARRAY1 = byteArrayOf(1)
+    private val ARRAY234 = byteArrayOf(2, 3, 4)
+
+    @Test
+    fun testIndexOf() {
+        assertEquals(-1, indexOf(EMPTY, 1))
+        assertEquals(-1, indexOf(ARRAY1, 2))
+        assertEquals(-1, indexOf(ARRAY234, 1))
+        assertEquals(0, indexOf(byteArrayOf(-1), -1))
+        assertEquals(0, indexOf(ARRAY234, 2))
+        assertEquals(1, indexOf(ARRAY234, 3))
+        assertEquals(2, indexOf(ARRAY234, 4))
+        assertEquals(1, indexOf(byteArrayOf(2, 3, 2, 3), 3))
+    }
+
+    @Test
+    fun testConcat() {
+        assertContentEquals(EMPTY, concat())
+        assertContentEquals(EMPTY, concat(EMPTY))
+        assertContentEquals(EMPTY, concat(EMPTY, EMPTY, EMPTY))
+        assertContentEquals(ARRAY1, concat(ARRAY1))
+        assertNotSame(ARRAY1, concat(ARRAY1))
+        assertContentEquals(ARRAY1, concat(EMPTY, ARRAY1, EMPTY))
+        assertContentEquals(byteArrayOf(1, 1, 1), concat(ARRAY1, ARRAY1, ARRAY1))
+        assertContentEquals(byteArrayOf(1, 2, 3, 4), concat(ARRAY1, ARRAY234))
+    }
+}
diff --git a/staticlibs/tests/unit/src/com/android/net/module/util/CollectionUtilsTest.kt b/staticlibs/tests/unit/src/com/android/net/module/util/CollectionUtilsTest.kt
index 0f00d0b..9fb025b 100644
--- a/staticlibs/tests/unit/src/com/android/net/module/util/CollectionUtilsTest.kt
+++ b/staticlibs/tests/unit/src/com/android/net/module/util/CollectionUtilsTest.kt
@@ -22,6 +22,7 @@
import org.junit.Test
import org.junit.runner.RunWith
import kotlin.test.assertEquals
+import kotlin.test.assertFailsWith
import kotlin.test.assertFalse
import kotlin.test.assertNull
import kotlin.test.assertSame
@@ -133,4 +134,52 @@
assertSame(CollectionUtils.findLast(listMulti) { it == "E" }, listMulti[7])
assertNull(CollectionUtils.findLast(listMulti) { it == "F" })
}
+
+    @Test
+    fun testMap() {
+        val input = listOf("A", "B", "C", "D", "E", null)
+        assertEquals(input.map { s -> "-$s-" }, CollectionUtils.map(input) { s -> "-$s-" })
+    }
+
+    @Test
+    fun testZip() {
+        val letters = listOf("A", "B", "C", "D", "E")
+        val numbers = listOf(1, 2, 3, 4, 5)
+        // Kotlin's own zip builds kotlin.Pair, so build android.util.Pair explicitly instead.
+        assertEquals(numbers.zip(letters) { n, s -> android.util.Pair(n, s) },
+                CollectionUtils.zip(numbers, letters))
+        val lettersWithNull = listOf("A", null, "B", "C", "D")
+        assertEquals(numbers.zip(lettersWithNull) { n, s -> android.util.Pair(n, s) },
+                CollectionUtils.zip(numbers, lettersWithNull))
+        assertEquals(emptyList<android.util.Pair<Int, Int>>(),
+                CollectionUtils.zip(emptyList<Int>(), emptyList<Int>()))
+        assertFailsWith<IllegalArgumentException> {
+            // Lists of different sizes are rejected.
+            CollectionUtils.zip(listOf(1, 2), numbers)
+        }
+    }
+
+    @Test
+    fun testAssoc() {
+        val keys = listOf(1, 2, 3, 4, 5)
+        val values = listOf("A", "B", "C", "D", "A") // duplicate values are allowed
+        assertEquals(keys.zip(values).toMap(), CollectionUtils.assoc(keys, values))
+
+        // A null key is accepted like any other key.
+        val withNullKey = CollectionUtils.assoc(listOf(1, 2, null), listOf("A", "B", "C"))
+        assertEquals("C", withNullKey[null])
+
+        assertFailsWith<IllegalArgumentException> {
+            // Duplicate keys are rejected.
+            CollectionUtils.assoc(listOf("A", "B", "A"), listOf(1, 2, 3))
+        }
+        assertFailsWith<IllegalArgumentException> {
+            // ...including a duplicated null key.
+            CollectionUtils.assoc(listOf(null, "B", null), listOf(1, 2, 3))
+        }
+        assertFailsWith<IllegalArgumentException> {
+            // Lists of different sizes are rejected.
+            CollectionUtils.assoc(listOf(1, 2), keys)
+        }
+    }
}
diff --git a/staticlibs/tests/unit/src/com/android/net/module/util/NetworkCapabilitiesUtilsTest.kt b/staticlibs/tests/unit/src/com/android/net/module/util/NetworkCapabilitiesUtilsTest.kt
index 256ea1e..958f45f 100644
--- a/staticlibs/tests/unit/src/com/android/net/module/util/NetworkCapabilitiesUtilsTest.kt
+++ b/staticlibs/tests/unit/src/com/android/net/module/util/NetworkCapabilitiesUtilsTest.kt
@@ -22,7 +22,6 @@
import android.net.NetworkCapabilities.NET_CAPABILITY_CBS
import android.net.NetworkCapabilities.NET_CAPABILITY_EIMS
import android.net.NetworkCapabilities.NET_CAPABILITY_INTERNET
-import android.net.NetworkCapabilities.NET_CAPABILITY_INTERNET
import android.net.NetworkCapabilities.NET_CAPABILITY_OEM_PAID
import android.net.NetworkCapabilities.TRANSPORT_BLUETOOTH
import android.net.NetworkCapabilities.TRANSPORT_CELLULAR
@@ -38,8 +37,6 @@
import com.android.net.module.util.NetworkCapabilitiesUtils.RESTRICTED_CAPABILITIES
import com.android.net.module.util.NetworkCapabilitiesUtils.UNRESTRICTED_CAPABILITIES
import com.android.net.module.util.NetworkCapabilitiesUtils.getDisplayTransport
-import com.android.net.module.util.NetworkCapabilitiesUtils.packBits
-import com.android.net.module.util.NetworkCapabilitiesUtils.unpackBits
import org.junit.Test
import org.junit.runner.RunWith
import java.lang.IllegalArgumentException
@@ -75,23 +72,6 @@
}
}
- @Test
- fun testBitPackingTestCase() {
- runBitPackingTestCase(0, intArrayOf())
- runBitPackingTestCase(1, intArrayOf(0))
- runBitPackingTestCase(3, intArrayOf(0, 1))
- runBitPackingTestCase(4, intArrayOf(2))
- runBitPackingTestCase(63, intArrayOf(0, 1, 2, 3, 4, 5))
- runBitPackingTestCase(Long.MAX_VALUE.inv(), intArrayOf(63))
- runBitPackingTestCase(Long.MAX_VALUE.inv() + 1, intArrayOf(0, 63))
- runBitPackingTestCase(Long.MAX_VALUE.inv() + 2, intArrayOf(1, 63))
- }
-
- fun runBitPackingTestCase(packedBits: Long, bits: IntArray) {
- assertEquals(packedBits, packBits(bits))
- assertTrue(bits contentEquals unpackBits(packedBits))
- }
-
// NetworkCapabilities constructor and Builder are not available until R. Mark TargetApi to
// ignore the linter error since it's used in only unit test.
@Test @TargetApi(Build.VERSION_CODES.R)
diff --git a/staticlibs/tests/unit/src/com/android/net/module/util/TrackRecordTest.kt b/staticlibs/tests/unit/src/com/android/net/module/util/TrackRecordTest.kt
index 9fb4d8c..8e320d0 100644
--- a/staticlibs/tests/unit/src/com/android/net/module/util/TrackRecordTest.kt
+++ b/staticlibs/tests/unit/src/com/android/net/module/util/TrackRecordTest.kt
@@ -17,6 +17,7 @@
package com.android.net.module.util
import com.android.testutils.ConcurrentInterpreter
+import com.android.testutils.INTERPRET_TIME_UNIT
import com.android.testutils.InterpretException
import com.android.testutils.InterpretMatcher
import com.android.testutils.SyntaxException
@@ -420,7 +421,7 @@
// the test code to not compile instead of throw, but it's vastly more complex and this will
// fail 100% at runtime any test that would not have compiled.
Regex("""poll\((\d+)?\)\s*(\{.*\})?""") to { i, t, r ->
- (if (r.strArg(1).isEmpty()) i.interpretTimeUnit else r.timeArg(1)).let { time ->
+ (if (r.strArg(1).isEmpty()) INTERPRET_TIME_UNIT else r.timeArg(1)).let { time ->
(t as ArrayTrackRecord<Int>.ReadHead).poll(time, makePredicate(r.strArg(2)))
}
},
diff --git a/staticlibs/tests/unit/src/com/android/testutils/TestableNetworkCallbackTest.kt b/staticlibs/tests/unit/src/com/android/testutils/TestableNetworkCallbackTest.kt
new file mode 100644
index 0000000..f8f2da0
--- /dev/null
+++ b/staticlibs/tests/unit/src/com/android/testutils/TestableNetworkCallbackTest.kt
@@ -0,0 +1,477 @@
+/*
+ * Copyright (C) 2022 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.testutils
+
+import android.annotation.SuppressLint
+import android.net.LinkAddress
+import android.net.LinkProperties
+import android.net.Network
+import android.net.NetworkCapabilities
+import com.android.testutils.RecorderCallback.CallbackEntry
+import com.android.testutils.RecorderCallback.CallbackEntry.Companion.AVAILABLE
+import com.android.testutils.RecorderCallback.CallbackEntry.Companion.BLOCKED_STATUS
+import com.android.testutils.RecorderCallback.CallbackEntry.Companion.LINK_PROPERTIES_CHANGED
+import com.android.testutils.RecorderCallback.CallbackEntry.Companion.LOSING
+import com.android.testutils.RecorderCallback.CallbackEntry.Companion.NETWORK_CAPS_UPDATED
+import com.android.testutils.RecorderCallback.CallbackEntry.Companion.LOST
+import com.android.testutils.RecorderCallback.CallbackEntry.Companion.RESUMED
+import com.android.testutils.RecorderCallback.CallbackEntry.Companion.SUSPENDED
+import com.android.testutils.RecorderCallback.CallbackEntry.Companion.UNAVAILABLE
+import com.android.testutils.RecorderCallback.CallbackEntry.Available
+import com.android.testutils.RecorderCallback.CallbackEntry.BlockedStatus
+import com.android.testutils.RecorderCallback.CallbackEntry.CapabilitiesChanged
+import org.junit.Before
+import org.junit.Test
+import org.junit.runner.RunWith
+import org.junit.runners.JUnit4
+import org.junit.Assume.assumeTrue
+import kotlin.reflect.KClass
+import kotlin.test.assertEquals
+import kotlin.test.assertFails
+import kotlin.test.assertNull
+import kotlin.test.assertTrue
+import kotlin.test.fail
+
+const val SHORT_TIMEOUT_MS: Long = 20
+const val DEFAULT_LINGER_DELAY_MS: Int = 30_000
+const val NOT_METERED: Int = NetworkCapabilities.NET_CAPABILITY_NOT_METERED
+const val WIFI: Int = NetworkCapabilities.TRANSPORT_WIFI
+const val CELLULAR: Int = NetworkCapabilities.TRANSPORT_CELLULAR
+const val TEST_INTERFACE_NAME: String = "testInterfaceName"
+
+@RunWith(JUnit4::class)
+@SuppressLint("NewApi") // Uses hidden APIs, which the linter would identify as missing APIs.
+class TestableNetworkCallbackTest {
+    private lateinit var mCallback: TestableNetworkCallback
+
+    private fun makeHasNetwork(netId: Int) = object : TestableNetworkCallback.HasNetwork {
+        override val network: Network = Network(netId)
+    }
+
+    @Before
+    fun setUp() {
+        mCallback = TestableNetworkCallback()
+    }
+
+    @Test
+    fun testLastAvailableNetwork() {
+        // Make sure there is no last available network at first, then the last available network
+        // is returned after onAvailable is called.
+        val net2097 = Network(2097)
+        assertNull(mCallback.lastAvailableNetwork)
+        mCallback.onAvailable(net2097)
+        assertEquals(net2097, mCallback.lastAvailableNetwork)
+
+        // Make sure calling onCapsChanged/onLinkPropertiesChanged doesn't affect the last
+        // available network.
+        mCallback.onCapabilitiesChanged(net2097, NetworkCapabilities())
+        mCallback.onLinkPropertiesChanged(net2097, LinkProperties())
+        assertEquals(net2097, mCallback.lastAvailableNetwork)
+
+        // Make sure onLost clears the last available network.
+        mCallback.onLost(net2097)
+        assertNull(mCallback.lastAvailableNetwork)
+
+        // Do the same but with a different network after onLost : make sure the last available
+        // network is the new one, not the original one.
+        val net2098 = Network(2098)
+        mCallback.onAvailable(net2098)
+        mCallback.onCapabilitiesChanged(net2098, NetworkCapabilities())
+        mCallback.onLinkPropertiesChanged(net2098, LinkProperties())
+        assertEquals(net2098, mCallback.lastAvailableNetwork)
+
+        // Make sure onAvailable changes the last available network even if onLost was not called.
+        val net2099 = Network(2099)
+        mCallback.onAvailable(net2099)
+        assertEquals(net2099, mCallback.lastAvailableNetwork)
+
+        // For legacy reasons, lastAvailableNetwork is null as soon as any is lost, not necessarily
+        // the last available one. Check that behavior.
+        mCallback.onLost(net2098)
+        assertNull(mCallback.lastAvailableNetwork)
+
+        // Make sure that losing the really last available one still results in null.
+        mCallback.onLost(net2099)
+        assertNull(mCallback.lastAvailableNetwork)
+
+        // Make sure multiple onAvailable in a row then onLost still results in null.
+        mCallback.onAvailable(net2097)
+        mCallback.onAvailable(net2098)
+        mCallback.onAvailable(net2099)
+        mCallback.onLost(net2097)
+        assertNull(mCallback.lastAvailableNetwork)
+    }
+
+    @Test
+    fun testAssertNoCallback() {
+        mCallback.assertNoCallback(SHORT_TIMEOUT_MS)
+        mCallback.onAvailable(Network(100))
+        assertFails { mCallback.assertNoCallback(SHORT_TIMEOUT_MS) }
+    }
+
+    @Test
+    fun testAssertNoCallbackThat() {
+        val net = Network(101)
+        mCallback.assertNoCallbackThat { it is Available }
+        mCallback.onAvailable(net)
+        // Expect no blocked status change. Receiving other callbacks does not fail the test.
+        mCallback.assertNoCallbackThat { it is BlockedStatus }
+        mCallback.onBlockedStatusChanged(net, true)
+        assertFails { mCallback.assertNoCallbackThat { it is BlockedStatus } }
+        mCallback.onBlockedStatusChanged(net, false)
+        mCallback.onCapabilitiesChanged(net, NetworkCapabilities())
+        assertFails { mCallback.assertNoCallbackThat { it is CapabilitiesChanged } }
+    }
+
+    @Test
+    fun testCapabilitiesWithAndWithout() {
+        val net = Network(101)
+        val matcher = makeHasNetwork(101)
+        val meteredNc = NetworkCapabilities()
+        val unmeteredNc = NetworkCapabilities().addCapability(NOT_METERED)
+        // Check that expecting caps (with or without) fails when no callback has been received.
+        assertFails { mCallback.expectCapabilitiesWith(NOT_METERED, matcher, SHORT_TIMEOUT_MS) }
+        assertFails { mCallback.expectCapabilitiesWithout(NOT_METERED, matcher, SHORT_TIMEOUT_MS) }
+
+        // Add NOT_METERED and check that With succeeds and Without fails.
+        mCallback.onCapabilitiesChanged(net, unmeteredNc)
+        mCallback.expectCapabilitiesWith(NOT_METERED, matcher)
+        mCallback.onCapabilitiesChanged(net, unmeteredNc)
+        assertFails { mCallback.expectCapabilitiesWithout(NOT_METERED, matcher, SHORT_TIMEOUT_MS) }
+
+        // Don't add NOT_METERED and check that With fails and Without succeeds.
+        mCallback.onCapabilitiesChanged(net, meteredNc)
+        assertFails { mCallback.expectCapabilitiesWith(NOT_METERED, matcher, SHORT_TIMEOUT_MS) }
+        mCallback.onCapabilitiesChanged(net, meteredNc)
+        mCallback.expectCapabilitiesWithout(NOT_METERED, matcher)
+    }
+
+    @Test
+    fun testExpectWithPredicate() {
+        val net = Network(193)
+        val netCaps = NetworkCapabilities().addTransportType(CELLULAR)
+        // Check that expecting callbackThat anything fails when no callback has been received.
+        assertFails { mCallback.expect<CallbackEntry>(timeoutMs = SHORT_TIMEOUT_MS) { true } }
+
+        // Basic test for true and false
+        mCallback.onAvailable(net)
+        mCallback.expect<Available> { true }
+        mCallback.onAvailable(net)
+        assertFails { mCallback.expect<CallbackEntry>(timeoutMs = SHORT_TIMEOUT_MS) { false } }
+
+        // Try a positive and a negative case
+        mCallback.onBlockedStatusChanged(net, true)
+        mCallback.expect<CallbackEntry> { cb -> cb is BlockedStatus && cb.blocked }
+        mCallback.onCapabilitiesChanged(net, netCaps)
+        assertFails { mCallback.expect<CallbackEntry>(timeoutMs = SHORT_TIMEOUT_MS) { cb ->
+            cb is CapabilitiesChanged && cb.caps.hasTransport(WIFI)
+        } }
+    }
+
+    @Test
+    fun testCapabilitiesThat() {
+        val net = Network(101)
+        val netCaps = NetworkCapabilities().addCapability(NOT_METERED).addTransportType(WIFI)
+        // Check that expecting capabilitiesThat anything fails when no callback has been received.
+        assertFails { mCallback.expectCapabilitiesThat(net, SHORT_TIMEOUT_MS) { true } }
+
+        // Basic test for true and false
+        mCallback.onCapabilitiesChanged(net, netCaps)
+        mCallback.expectCapabilitiesThat(net) { true }
+        mCallback.onCapabilitiesChanged(net, netCaps)
+        assertFails { mCallback.expectCapabilitiesThat(net, SHORT_TIMEOUT_MS) { false } }
+
+        // Try a positive and a negative case
+        mCallback.onCapabilitiesChanged(net, netCaps)
+        mCallback.expectCapabilitiesThat(net) { caps ->
+            caps.hasCapability(NOT_METERED) &&
+                    caps.hasTransport(WIFI) &&
+                    !caps.hasTransport(CELLULAR)
+        }
+        mCallback.onCapabilitiesChanged(net, netCaps)
+        assertFails { mCallback.expectCapabilitiesThat(net, SHORT_TIMEOUT_MS) { caps ->
+            caps.hasTransport(CELLULAR)
+        } }
+
+        // Try a matching callback on the wrong network
+        mCallback.onCapabilitiesChanged(net, netCaps)
+        assertFails { mCallback.expectCapabilitiesThat(Network(100), SHORT_TIMEOUT_MS) { true } }
+    }
+
+    @Test
+    fun testLinkPropertiesThat() {
+        val net = Network(112)
+        val linkAddress = LinkAddress("fe80::ace:d00d/64")
+        val mtu = 1984
+        val linkProps = LinkProperties().apply {
+            this.mtu = mtu
+            interfaceName = TEST_INTERFACE_NAME
+            addLinkAddress(linkAddress)
+        }
+
+        // Check that expecting linkPropsThat anything fails when no callback has been received.
+        assertFails { mCallback.expectLinkPropertiesThat(net, SHORT_TIMEOUT_MS) { true } }
+
+        // Basic test for true and false
+        mCallback.onLinkPropertiesChanged(net, linkProps)
+        mCallback.expectLinkPropertiesThat(net) { true }
+        mCallback.onLinkPropertiesChanged(net, linkProps)
+        assertFails { mCallback.expectLinkPropertiesThat(net, SHORT_TIMEOUT_MS) { false } }
+
+        // Try a positive and negative case
+        mCallback.onLinkPropertiesChanged(net, linkProps)
+        mCallback.expectLinkPropertiesThat(net) { lp ->
+            lp.interfaceName == TEST_INTERFACE_NAME &&
+                    lp.linkAddresses.contains(linkAddress) &&
+                    lp.mtu == mtu
+        }
+        mCallback.onLinkPropertiesChanged(net, linkProps)
+        assertFails { mCallback.expectLinkPropertiesThat(net, SHORT_TIMEOUT_MS) { lp ->
+            lp.interfaceName != TEST_INTERFACE_NAME
+        } }
+
+        // Try a matching callback on the wrong network
+        mCallback.onLinkPropertiesChanged(net, linkProps)
+        assertFails { mCallback.expectLinkPropertiesThat(Network(114), SHORT_TIMEOUT_MS) { lp ->
+            lp.interfaceName == TEST_INTERFACE_NAME
+        } }
+    }
+
+    @Test
+    fun testExpect() {
+        val net = Network(103)
+        // Test expect fails when nothing was sent.
+        assertFails { mCallback.expect<BlockedStatus>(net, SHORT_TIMEOUT_MS) }
+
+        // Test onAvailable is seen and can be expected
+        mCallback.onAvailable(net)
+        mCallback.expect<Available>(net, SHORT_TIMEOUT_MS)
+
+        // Test expect won't match an onAvailable call for a different network
+        mCallback.onAvailable(Network(106))
+        assertFails { mCallback.expect<Available>(net, SHORT_TIMEOUT_MS) }
+
+        // Test expect won't match a callback of a different type
+        mCallback.onAvailable(net)
+        assertFails { mCallback.expect<BlockedStatus>(net, SHORT_TIMEOUT_MS) }
+    }
+
+    @Test
+    fun testAllExpectOverloads() {
+        // This test should never run, it only checks that all overloads exist and build
+        assumeTrue(false)
+        val hn = object : TestableNetworkCallback.HasNetwork { override val network = ANY_NETWORK }
+
+        // Method with all arguments (version that takes a Network)
+        mCallback.expect(AVAILABLE, ANY_NETWORK, 10, "error") { true }
+
+        // Java overloads omitting one argument. One line for omitting each argument, in positional
+        // order. Versions that take a Network.
+        mCallback.expect(AVAILABLE, 10, "error") { true }
+        mCallback.expect(AVAILABLE, ANY_NETWORK, "error") { true }
+        mCallback.expect(AVAILABLE, ANY_NETWORK, 10) { true }
+        mCallback.expect(AVAILABLE, ANY_NETWORK, 10, "error")
+
+        // Java overloads for omitting two arguments. One line for omitting each pair of arguments.
+        // Versions that take a Network.
+        mCallback.expect(AVAILABLE, "error") { true }
+        mCallback.expect(AVAILABLE, 10) { true }
+        mCallback.expect(AVAILABLE, 10, "error")
+        mCallback.expect(AVAILABLE, ANY_NETWORK) { true }
+        mCallback.expect(AVAILABLE, ANY_NETWORK, "error")
+        mCallback.expect(AVAILABLE, ANY_NETWORK, 10)
+
+        // Java overloads for omitting three arguments. One line for each remaining argument.
+        // Versions that take a Network.
+        mCallback.expect(AVAILABLE) { true }
+        mCallback.expect(AVAILABLE, "error")
+        mCallback.expect(AVAILABLE, 10)
+        mCallback.expect(AVAILABLE, ANY_NETWORK)
+
+        // Java overload for omitting all four arguments.
+        mCallback.expect(AVAILABLE)
+
+        // Same orders as above, but versions that take a HasNetwork. Except overloads that
+        // were already tested because they omitted the Network argument
+        mCallback.expect(AVAILABLE, hn, 10, "error") { true }
+        mCallback.expect(AVAILABLE, hn, "error") { true }
+        mCallback.expect(AVAILABLE, hn, 10) { true }
+        mCallback.expect(AVAILABLE, hn, 10, "error")
+
+        mCallback.expect(AVAILABLE, hn) { true }
+        mCallback.expect(AVAILABLE, hn, "error")
+        mCallback.expect(AVAILABLE, hn, 10)
+
+        mCallback.expect(AVAILABLE, hn)
+
+        // Same as above but for reified versions.
+        mCallback.expect<Available>(ANY_NETWORK, 10, "error") { true }
+        mCallback.expect<Available>(timeoutMs = 10, errorMsg = "error") { true }
+        mCallback.expect<Available>(network = ANY_NETWORK, errorMsg = "error") { true }
+        mCallback.expect<Available>(network = ANY_NETWORK, timeoutMs = 10) { true }
+        mCallback.expect<Available>(network = ANY_NETWORK, timeoutMs = 10, errorMsg = "error")
+
+        mCallback.expect<Available>(errorMsg = "error") { true }
+        mCallback.expect<Available>(timeoutMs = 10) { true }
+        mCallback.expect<Available>(timeoutMs = 10, errorMsg = "error")
+        mCallback.expect<Available>(network = ANY_NETWORK) { true }
+        mCallback.expect<Available>(network = ANY_NETWORK, errorMsg = "error")
+        mCallback.expect<Available>(network = ANY_NETWORK, timeoutMs = 10)
+
+        mCallback.expect<Available> { true }
+        mCallback.expect<Available>(errorMsg = "error")
+        mCallback.expect<Available>(timeoutMs = 10)
+        mCallback.expect<Available>(network = ANY_NETWORK)
+        mCallback.expect<Available>()
+
+        mCallback.expect<Available>(hn, 10, "error") { true }
+        mCallback.expect<Available>(network = hn, errorMsg = "error") { true }
+        mCallback.expect<Available>(network = hn, timeoutMs = 10) { true }
+        mCallback.expect<Available>(network = hn, timeoutMs = 10, errorMsg = "error")
+
+        mCallback.expect<Available>(network = hn) { true }
+        mCallback.expect<Available>(network = hn, errorMsg = "error")
+        mCallback.expect<Available>(network = hn, timeoutMs = 10)
+
+        mCallback.expect<Available>(network = hn)
+    }
+
+    @Test
+    fun testPoll() {
+        assertNull(mCallback.poll(SHORT_TIMEOUT_MS))
+        TNCInterpreter.interpretTestSpec(initial = mCallback, lineShift = 1,
+                threadTransform = { cb -> cb.createLinkedCopy() }, spec = """
+            sleep; onAvailable(133)    | poll(2) = Available(133) time 1..4
+                                       | poll(1) = null
+            onCapabilitiesChanged(108) | poll(1) = CapabilitiesChanged(108) time 0..3
+            onBlockedStatus(199)       | poll(1) = BlockedStatus(199) time 0..3
+        """)
+    }
+
+    @Test
+    fun testPollOrThrow() {
+        assertFails { mCallback.pollOrThrow(SHORT_TIMEOUT_MS) }
+        TNCInterpreter.interpretTestSpec(initial = mCallback, lineShift = 1,
+                threadTransform = { cb -> cb.createLinkedCopy() }, spec = """
+            sleep; onAvailable(133)    | pollOrThrow(2) = Available(133) time 1..4
+                                       | pollOrThrow(1) fails
+            onCapabilitiesChanged(108) | pollOrThrow(1) = CapabilitiesChanged(108) time 0..3
+            onBlockedStatus(199)       | pollOrThrow(1) = BlockedStatus(199) time 0..3
+        """)
+    }
+
+    @Test
+    fun testEventuallyExpect() {
+        // TODO: Current test does not verify the inline one. Also verify the behavior after
+        // aligning two eventuallyExpect()
+        val net1 = Network(100)
+        val net2 = Network(101)
+        mCallback.onAvailable(net1)
+        mCallback.onCapabilitiesChanged(net1, NetworkCapabilities())
+        mCallback.onLinkPropertiesChanged(net1, LinkProperties())
+        mCallback.eventuallyExpect(LINK_PROPERTIES_CHANGED) {
+            net1.equals(it.network)
+        }
+        // No further new callback. Expect no callback.
+        assertFails { mCallback.eventuallyExpect(LINK_PROPERTIES_CHANGED, SHORT_TIMEOUT_MS) }
+
+        // Verify that earlier callbacks of other types are skipped until the expected one.
+        mCallback.onAvailable(net2)
+        mCallback.onLinkPropertiesChanged(net2, LinkProperties())
+        mCallback.onBlockedStatusChanged(net1, false)
+        mCallback.eventuallyExpect(BLOCKED_STATUS) { net1.equals(it.network) }
+        // Verify eventuallyExpect fails if the expected callback never arrives.
+        assertFails { mCallback.eventuallyExpect(LOSING, SHORT_TIMEOUT_MS) }
+    }
+
+    @Test
+    fun testEventuallyExpectOnMultiThreads() {
+        TNCInterpreter.interpretTestSpec(initial = mCallback, lineShift = 1,
+                threadTransform = { cb -> cb.createLinkedCopy() }, spec = """
+            onAvailable(100)                        | eventually(CapabilitiesChanged(100), 1) fails
+            sleep ; onCapabilitiesChanged(100)      | eventually(CapabilitiesChanged(100), 3)
+            onAvailable(101) ; onBlockedStatus(101) | eventually(BlockedStatus(100), 2) fails
+            onSuspended(100) ; sleep ; onLost(100)  | eventually(Lost(100), 3)
+        """)
+    }
+}
+
+private object TNCInterpreter : ConcurrentInterpreter<TestableNetworkCallback>(interpretTable)
+
+val EntryList = CallbackEntry::class.sealedSubclasses.joinToString("|") { "${it.simpleName}" }
+private fun callbackEntryFromString(name: String): KClass<out CallbackEntry> =
+        CallbackEntry::class.sealedSubclasses
+                .first { candidate -> candidate.simpleName == name }
+
+@SuppressLint("NewApi") // Uses hidden APIs, which the linter would identify as missing APIs.
+private val interpretTable = listOf<InterpretMatcher<TestableNetworkCallback>>(
+    // Interpret "expr = Available(xx)" as "evaluate expr and check it returned an Available
+    // callback with netId xx", and likewise for all callback types. The type names come from
+    // EntryList, which enumerates the subclasses of CallbackEntry by their simpleName.
+    Regex("""(.*)\s+=\s+($EntryList)\((\d+)\)""") to { i, cb, t ->
+        val record = i.interpret(t.strArg(1), cb)
+        assertTrue(callbackEntryFromString(t.strArg(2)).isInstance(record))
+        // Strictly speaking testing for is CallbackEntry is useless as it's been tested above
+        // but the compiler can't figure things out from the isInstance call. It does understand
+        // from the assertTrue(is CallbackEntry) that this is true, which allows to access
+        // the 'network' member below.
+        assertTrue(record is CallbackEntry)
+        assertEquals(t.intArg(3), record.network.netId)
+    },
+    // Interpret "onAvailable(xx)" as calling "onAvailable" with a netId of xx, and likewise for
+    // all callback types. NetworkCapabilities and LinkProperties just get an empty object
+    // as their argument. Losing gets the default linger timer. Blocked gets false.
+    Regex("""on($EntryList)\((\d+)\)""") to { _, cb, t ->
+        val net = Network(t.intArg(2))
+        when (t.strArg(1)) {
+            "Available" -> cb.onAvailable(net)
+            // PreCheck not used in tests. Add it here if it becomes useful.
+            "CapabilitiesChanged" -> cb.onCapabilitiesChanged(net, NetworkCapabilities())
+            "LinkPropertiesChanged" -> cb.onLinkPropertiesChanged(net, LinkProperties())
+            "Suspended" -> cb.onNetworkSuspended(net)
+            "Resumed" -> cb.onNetworkResumed(net)
+            "Losing" -> cb.onLosing(net, DEFAULT_LINGER_DELAY_MS)
+            "Lost" -> cb.onLost(net)
+            "Unavailable" -> cb.onUnavailable()
+            "BlockedStatus" -> cb.onBlockedStatusChanged(net, false)
+            else -> fail("Unknown callback type")
+        }
+    },
+    Regex("""poll\((\d+)\)""") to { _, cb, t -> cb.poll(t.timeArg(1)) },
+    Regex("""pollOrThrow\((\d+)\)""") to { _, cb, t -> cb.pollOrThrow(t.timeArg(1)) },
+    // Interpret "eventually(Available(xx), timeout)" as calling eventuallyExpect that expects
+    // CallbackEntry.AVAILABLE with netId of xx within timeout*INTERPRET_TIME_UNIT timeout, and
+    // likewise for all callback types.
+    Regex("""eventually\(($EntryList)\((\d+)\),\s+(\d+)\)""") to { _, cb, t ->
+        val net = Network(t.intArg(2))
+        val timeout = t.timeArg(3)
+        when (t.strArg(1)) {
+            "Available" -> cb.eventuallyExpect(AVAILABLE, timeout) { net == it.network }
+            "Suspended" -> cb.eventuallyExpect(SUSPENDED, timeout) { net == it.network }
+            "Resumed" -> cb.eventuallyExpect(RESUMED, timeout) { net == it.network }
+            "Losing" -> cb.eventuallyExpect(LOSING, timeout) { net == it.network }
+            "Lost" -> cb.eventuallyExpect(LOST, timeout) { net == it.network }
+            "Unavailable" -> cb.eventuallyExpect(UNAVAILABLE, timeout) { net == it.network }
+            "BlockedStatus" -> cb.eventuallyExpect(BLOCKED_STATUS, timeout) { net == it.network }
+            "CapabilitiesChanged" ->
+                cb.eventuallyExpect(NETWORK_CAPS_UPDATED, timeout) { net == it.network }
+            "LinkPropertiesChanged" ->
+                cb.eventuallyExpect(LINK_PROPERTIES_CHANGED, timeout) { net == it.network }
+            else -> fail("Unknown callback type")
+        }
+    }
+)
diff --git a/staticlibs/tests/unit/src/com/android/testutils/TestableNetworkCallbackTestJava.java b/staticlibs/tests/unit/src/com/android/testutils/TestableNetworkCallbackTestJava.java
new file mode 100644
index 0000000..4570d0a
--- /dev/null
+++ b/staticlibs/tests/unit/src/com/android/testutils/TestableNetworkCallbackTestJava.java
@@ -0,0 +1,76 @@
+/*
+ * Copyright (C) 2022 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.testutils;
+
+import static com.android.testutils.RecorderCallback.CallbackEntry.AVAILABLE;
+import static com.android.testutils.TestableNetworkCallbackKt.anyNetwork;
+
+import static org.junit.Assume.assumeTrue;
+
+import org.junit.Test;
+
+public class TestableNetworkCallbackTestJava {
+    @Test
+    public void testAllExpectOverloads() { // JUnit4 requires @Test methods to be public
+        // This test should never run, it only checks that all overloads exist and build
+        assumeTrue(false);
+        final TestableNetworkCallback callback = new TestableNetworkCallback();
+        TestableNetworkCallback.HasNetwork hn = TestableNetworkCallbackKt::anyNetwork;
+
+        // Method with all arguments (version that takes a Network)
+        callback.expect(AVAILABLE, anyNetwork(), 10, "error", cb -> true);
+
+        // Overloads omitting one argument. One line for omitting each argument, in positional
+        // order. Versions that take a Network.
+        callback.expect(AVAILABLE, 10, "error", cb -> true);
+        callback.expect(AVAILABLE, anyNetwork(), "error", cb -> true);
+        callback.expect(AVAILABLE, anyNetwork(), 10, cb -> true);
+        callback.expect(AVAILABLE, anyNetwork(), 10, "error");
+
+        // Overloads for omitting two arguments. One line for omitting each pair of arguments.
+        // Versions that take a Network.
+        callback.expect(AVAILABLE, "error", cb -> true);
+        callback.expect(AVAILABLE, 10, cb -> true);
+        callback.expect(AVAILABLE, 10, "error");
+        callback.expect(AVAILABLE, anyNetwork(), cb -> true);
+        callback.expect(AVAILABLE, anyNetwork(), "error");
+        callback.expect(AVAILABLE, anyNetwork(), 10);
+
+        // Overloads for omitting three arguments. One line for each remaining argument.
+        // Versions that take a Network.
+        callback.expect(AVAILABLE, cb -> true);
+        callback.expect(AVAILABLE, "error");
+        callback.expect(AVAILABLE, 10);
+        callback.expect(AVAILABLE, anyNetwork());
+
+        // Java overload for omitting all four arguments.
+        callback.expect(AVAILABLE);
+
+        // Same orders as above, but versions that take a HasNetwork. Except overloads that
+        // were already tested because they omitted the Network argument
+        callback.expect(AVAILABLE, hn, 10, "error", cb -> true);
+        callback.expect(AVAILABLE, hn, "error", cb -> true);
+        callback.expect(AVAILABLE, hn, 10, cb -> true);
+        callback.expect(AVAILABLE, hn, 10, "error");
+
+        callback.expect(AVAILABLE, hn, cb -> true);
+        callback.expect(AVAILABLE, hn, "error");
+        callback.expect(AVAILABLE, hn, 10);
+
+        callback.expect(AVAILABLE, hn);
+    }
+}
diff --git a/staticlibs/testutils/devicetests/com/android/testutils/ConcurrentInterpreter.kt b/staticlibs/testutils/devicetests/com/android/testutils/ConcurrentInterpreter.kt
index cbdc017..9e72f4b 100644
--- a/staticlibs/testutils/devicetests/com/android/testutils/ConcurrentInterpreter.kt
+++ b/staticlibs/testutils/devicetests/com/android/testutils/ConcurrentInterpreter.kt
@@ -13,7 +13,7 @@
typealias InterpretMatcher<T> = Pair<Regex, (ConcurrentInterpreter<T>, T, MatchResult) -> Any?>
// The default unit of time for interpreted tests
-val INTERPRET_TIME_UNIT = 40L // ms
+const val INTERPRET_TIME_UNIT: Long = 60 // milliseconds
/**
* A small interpreter for testing parallel code.
@@ -40,10 +40,7 @@
* Some expressions already exist by default and can be used by all interpreters. Refer to
* getDefaultInstructions() below for a list and documentation.
*/
-open class ConcurrentInterpreter<T>(
- localInterpretTable: List<InterpretMatcher<T>>,
- val interpretTimeUnit: Long = INTERPRET_TIME_UNIT
-) {
+open class ConcurrentInterpreter<T>(localInterpretTable: List<InterpretMatcher<T>>) {
private val interpretTable: List<InterpretMatcher<T>> =
localInterpretTable + getDefaultInstructions()
// The last time the thread became blocked, with base System.currentTimeMillis(). This should
@@ -211,7 +208,7 @@
},
// Interpret sleep. Optional argument for the count, in INTERPRET_TIME_UNIT units.
Regex("""sleep(\((\d+)\))?""") to { i, t, r ->
- SystemClock.sleep(if (r.strArg(2).isEmpty()) i.interpretTimeUnit else r.timeArg(2))
+ SystemClock.sleep(if (r.strArg(2).isEmpty()) INTERPRET_TIME_UNIT else r.timeArg(2))
},
Regex("""(.*)\s*fails""") to { i, t, r ->
assertFails { i.interpret(r.strArg(1), t) }
diff --git a/staticlibs/testutils/devicetests/com/android/testutils/ConnectUtil.kt b/staticlibs/testutils/devicetests/com/android/testutils/ConnectUtil.kt
index fc951d8..7b5ad01 100644
--- a/staticlibs/testutils/devicetests/com/android/testutils/ConnectUtil.kt
+++ b/staticlibs/testutils/devicetests/com/android/testutils/ConnectUtil.kt
@@ -63,6 +63,7 @@
try {
val connInfo = wifiManager.connectionInfo
+ Log.d(TAG, "connInfo=" + connInfo)
if (connInfo == null || connInfo.networkId == -1) {
clearWifiBlocklist()
val pfd = getInstrumentation().uiAutomation.executeShellCommand("svc wifi enable")
@@ -75,7 +76,7 @@
timeoutMs = WIFI_CONNECT_TIMEOUT_MS)
assertNotNull(cb, "Could not connect to a wifi access point within " +
- "$WIFI_CONNECT_INTERVAL_MS ms. Check that the test device has a wifi network " +
+ "$WIFI_CONNECT_TIMEOUT_MS ms. Check that the test device has a wifi network " +
"configured, and that the test access point is functioning properly.")
return cb.network
} finally {
diff --git a/staticlibs/testutils/devicetests/com/android/testutils/DeviceConfigRule.kt b/staticlibs/testutils/devicetests/com/android/testutils/DeviceConfigRule.kt
new file mode 100644
index 0000000..3d98cc3
--- /dev/null
+++ b/staticlibs/testutils/devicetests/com/android/testutils/DeviceConfigRule.kt
@@ -0,0 +1,176 @@
+/*
+ * Copyright (C) 2022 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.testutils
+
+import android.Manifest.permission.READ_DEVICE_CONFIG
+import android.Manifest.permission.WRITE_DEVICE_CONFIG
+import android.provider.DeviceConfig
+import android.util.Log
+import com.android.modules.utils.build.SdkLevel
+import com.android.testutils.FunctionalUtils.ThrowingRunnable
+import org.junit.rules.TestRule
+import org.junit.runner.Description
+import org.junit.runners.model.Statement
+import java.util.concurrent.CompletableFuture
+import java.util.concurrent.Executor
+import java.util.concurrent.TimeUnit
+
+private val TAG: String? = DeviceConfigRule::class.simpleName
+
+private const val TIMEOUT_MS: Long = 20_000 // Max wait for a config change to be applied
+
+/**
+ * A [TestRule] that helps set [DeviceConfig] for tests and clean up the test configuration
+ * automatically on teardown.
+ *
+ * The rule can also optionally retry tests when they fail following an external change of
+ * DeviceConfig before S; this typically happens because device config flags are synced while the
+ * test is running, and DisableConfigSyncTargetPreparer is only usable starting from S.
+ *
+ * @param retryCountBeforeSIfConfigChanged if > 0, when the test fails before S, check if
+ *        the configs that were set through this rule were changed, and retry the test
+ *        up to the specified number of times if yes.
+ */
+class DeviceConfigRule @JvmOverloads constructor(
+    val retryCountBeforeSIfConfigChanged: Int = 0
+) : TestRule {
+    // Maps (namespace, key) -> value
+    private val originalConfig = mutableMapOf<Pair<String, String>, String?>()
+    private val usedConfig = mutableMapOf<Pair<String, String>, String?>()
+
+    /**
+     * Actions to be run after cleanup of the config, for the current test only.
+     */
+    private val currentTestCleanupActions = mutableListOf<ThrowingRunnable>()
+
+    override fun apply(base: Statement, description: Description): Statement {
+        return DeviceConfigStatement(base, description)
+    }
+
+    private inner class DeviceConfigStatement(
+        private val base: Statement,
+        private val description: Description
+    ) : Statement() {
+        override fun evaluate() {
+            var retryCount = if (SdkLevel.isAtLeastS()) 1 else retryCountBeforeSIfConfigChanged + 1
+            while (retryCount > 0) {
+                retryCount--
+                tryTest {
+                    base.evaluate()
+                    // Can't use break/return out of a loop here because this is a tryTest lambda,
+                    // so set retryCount to exit instead
+                    retryCount = 0
+                }.catch<Throwable> { e -> // junit AssertionFailedError does not extend Exception
+                    if (retryCount == 0) throw e
+                    usedConfig.forEach { (key, value) ->
+                        val currentValue = runAsShell(READ_DEVICE_CONFIG) {
+                            DeviceConfig.getProperty(key.first, key.second)
+                        }
+                        if (currentValue != value) {
+                            Log.w(TAG, "Test failed with unexpected device config change, retrying")
+                            return@catch
+                        }
+                    }
+                    throw e
+                } cleanupStep {
+                    runAsShell(WRITE_DEVICE_CONFIG) {
+                        originalConfig.forEach { (key, value) ->
+                            DeviceConfig.setProperty(
+                                    key.first, key.second, value, false /* makeDefault */)
+                        }
+                    }
+                } cleanupStep {
+                    originalConfig.clear()
+                    usedConfig.clear()
+                } cleanup {
+                    // Fold all cleanup actions into cleanup steps of an empty tryTest, so they are
+                    // all run even if exceptions are thrown, and exceptions are reported properly.
+                    currentTestCleanupActions.fold(tryTest { }) {
+                        tryBlock, action -> tryBlock.cleanupStep { action.run() }
+                    }.cleanup {
+                        currentTestCleanupActions.clear()
+                    }
+                }
+            }
+        }
+    }
+
+    /**
+     * Set [key] in [namespace] to [value], wait until the change is applied and return [value].
+     * After the test case ends, it will be restored to the value it had when first set here.
+     */
+    fun setConfig(namespace: String, key: String, value: String?): String? {
+        Log.i(TAG, "Setting config \"$key\" to \"$value\"")
+        val readWritePermissions = arrayOf(READ_DEVICE_CONFIG, WRITE_DEVICE_CONFIG)
+
+        val keyPair = Pair(namespace, key)
+        val existingValue = runAsShell(*readWritePermissions) {
+            DeviceConfig.getProperty(namespace, key)
+        }
+        if (!originalConfig.containsKey(keyPair)) {
+            originalConfig[keyPair] = existingValue
+        }
+        usedConfig[keyPair] = value
+        if (existingValue == value) {
+            // Already the correct value. There may be a race if a change is already in flight,
+            // but if multiple threads update the config there is no way to fix that anyway.
+            Log.i(TAG, "\"$key\" already had value \"$value\"")
+            return value
+        }
+
+        val future = CompletableFuture<String>()
+        val listener = DeviceConfig.OnPropertiesChangedListener {
+            // The listener receives updates for any key in |namespace|, so don't react to
+            // changes that do not affect the relevant key
+            if (!it.keyset.contains(key)) return@OnPropertiesChangedListener
+            // "null" means absent in DeviceConfig : there is no such thing as a present but
+            // null value, so the following works even if |value| is null.
+            if (it.getString(key, null) == value) {
+                future.complete(value)
+            }
+        }
+
+        return tryTest {
+            runAsShell(*readWritePermissions) {
+                DeviceConfig.addOnPropertiesChangedListener(
+                        namespace,
+                        inlineExecutor,
+                        listener)
+                DeviceConfig.setProperty(
+                        namespace,
+                        key,
+                        value,
+                        false /* makeDefault */)
+                // Don't drop the permission until the config is applied, just in case
+                future.get(TIMEOUT_MS, TimeUnit.MILLISECONDS)
+            }.also {
+                Log.i(TAG, "Config \"$key\" successfully set to \"$value\"")
+            }
+        } cleanup {
+            DeviceConfig.removeOnPropertiesChangedListener(listener)
+        }
+    }
+
+    private val inlineExecutor get() = Executor { r -> r.run() }
+
+    /**
+     * Add an action to be run after config cleanup when the current test case ends.
+     */
+    fun runAfterNextCleanup(action: ThrowingRunnable) {
+        currentTestCleanupActions.add(action)
+    }
+}
diff --git a/staticlibs/testutils/devicetests/com/android/testutils/DeviceInfoUtils.java b/staticlibs/testutils/devicetests/com/android/testutils/DeviceInfoUtils.java
index ea89eda..ce55fdc 100644
--- a/staticlibs/testutils/devicetests/com/android/testutils/DeviceInfoUtils.java
+++ b/staticlibs/testutils/devicetests/com/android/testutils/DeviceInfoUtils.java
@@ -20,6 +20,7 @@
import android.text.TextUtils;
import android.util.Pair;
+import java.util.Objects;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
@@ -128,7 +129,7 @@
final Pair<Integer, Integer> v1 = getMajorMinorVersion(s1);
final Pair<Integer, Integer> v2 = getMajorMinorVersion(s2);
- if (v1.first == v2.first) {
+ if (Objects.equals(v1.first, v2.first)) {
return Integer.compare(v1.second, v2.second);
} else {
return Integer.compare(v1.first, v2.first);
diff --git a/staticlibs/testutils/devicetests/com/android/testutils/TestNetworkTracker.kt b/staticlibs/testutils/devicetests/com/android/testutils/TestNetworkTracker.kt
index 5ae2439..b743b6c 100644
--- a/staticlibs/testutils/devicetests/com/android/testutils/TestNetworkTracker.kt
+++ b/staticlibs/testutils/devicetests/com/android/testutils/TestNetworkTracker.kt
@@ -111,12 +111,12 @@
val tnm: TestNetworkManager,
val lp: LinkProperties?,
setupTimeoutMs: Long
-) {
+) : TestableNetworkCallback.HasNetwork {
private val cm = context.getSystemService(ConnectivityManager::class.java)
private val binder = Binder()
private val networkCallback: NetworkCallback
- val network: Network
+ override val network: Network
val testIface: TestNetworkInterface
init {
diff --git a/staticlibs/testutils/devicetests/com/android/testutils/TestableNetworkCallback.kt b/staticlibs/testutils/devicetests/com/android/testutils/TestableNetworkCallback.kt
index dffdbe8..b84f9a6 100644
--- a/staticlibs/testutils/devicetests/com/android/testutils/TestableNetworkCallback.kt
+++ b/staticlibs/testutils/devicetests/com/android/testutils/TestableNetworkCallback.kt
@@ -36,15 +36,12 @@
import kotlin.reflect.KClass
import kotlin.test.assertEquals
import kotlin.test.assertNotNull
-import kotlin.test.assertTrue
import kotlin.test.fail
object NULL_NETWORK : Network(-1)
object ANY_NETWORK : Network(-2)
fun anyNetwork() = ANY_NETWORK
-private val Int.capabilityName get() = NetworkCapabilities.capabilityNameOf(this)
-
open class RecorderCallback private constructor(
private val backingRecord: ArrayTrackRecord<CallbackEntry>
) : NetworkCallback() {
@@ -168,16 +165,22 @@
}
}
-private const val DEFAULT_TIMEOUT = 200L // ms
+private const val DEFAULT_TIMEOUT = 30_000L // ms
+private const val DEFAULT_NO_CALLBACK_TIMEOUT = 200L // ms
open class TestableNetworkCallback private constructor(
src: TestableNetworkCallback?,
- val defaultTimeoutMs: Long = DEFAULT_TIMEOUT
+ val defaultTimeoutMs: Long = DEFAULT_TIMEOUT,
+ val defaultNoCallbackTimeoutMs: Long = DEFAULT_NO_CALLBACK_TIMEOUT
) : RecorderCallback(src) {
@JvmOverloads
- constructor(timeoutMs: Long = DEFAULT_TIMEOUT): this(null, timeoutMs)
+ constructor(
+ timeoutMs: Long = DEFAULT_TIMEOUT,
+ noCallbackTimeoutMs: Long = DEFAULT_NO_CALLBACK_TIMEOUT
+ ): this(null, timeoutMs, noCallbackTimeoutMs)
- fun createLinkedCopy() = TestableNetworkCallback(this, defaultTimeoutMs)
+ fun createLinkedCopy() = TestableNetworkCallback(
+ this, defaultTimeoutMs, defaultNoCallbackTimeoutMs)
// The last available network, or null if any network was lost since the last call to
// onAvailable. TODO : fix this by fixing the tests that rely on this behavior
@@ -187,14 +190,163 @@
else -> null
}
- fun pollForNextCallback(timeoutMs: Long = defaultTimeoutMs): CallbackEntry {
- return history.poll(timeoutMs) ?: fail("Did not receive callback after ${timeoutMs}ms")
- }
+ /**
+ * Get the next callback or null if timeout.
+ *
+ * With no argument, this method waits out the default timeout. To wait forever, pass
+ * Long.MAX_VALUE.
+ *
+ * @param timeoutMs how long to wait for a callback, in milliseconds.
+ * @return the next callback from the history, or null if none arrived within [timeoutMs].
+ */
+ @JvmOverloads
+ fun poll(timeoutMs: Long = defaultTimeoutMs): CallbackEntry? = history.poll(timeoutMs)
+
+ /**
+ * Get the next callback or throw if timeout.
+ *
+ * With no argument, this method waits out the default timeout. To wait forever, pass
+ * Long.MAX_VALUE.
+ *
+ * @param timeoutMs how long to wait for a callback, in milliseconds.
+ * @param errorMsg the message passed to kotlin.test's fail() when no callback arrives.
+ * @return the next callback from the history.
+ */
+ @JvmOverloads
+ fun pollOrThrow(
+ timeoutMs: Long = defaultTimeoutMs,
+ // Include the unit, matching the other timeout messages in this class.
+ errorMsg: String = "Did not receive callback after ${timeoutMs}ms"
+ ): CallbackEntry = poll(timeoutMs) ?: fail(errorMsg)
+
+ /*****
+ * expect family of methods.
+ * These methods fetch the next callback and assert it matches the conditions : type,
+ * passed predicate. If no callback is received within the timeout, these methods fail.
+ *
+ * @param type the class the next callback must be an instance of.
+ * @param network the network the callback must be for; ANY_NETWORK accepts any network.
+ * @param timeoutMs how long to wait for the next callback, in milliseconds.
+ * @param errorMsg message used when |test| rejects the callback; null for the default.
+ * @param test extra predicate the callback must satisfy.
+ * @return the callback, cast to T.
+ */
+ @JvmOverloads
+ fun <T : CallbackEntry> expect(
+ type: KClass<T>,
+ network: Network = ANY_NETWORK,
+ timeoutMs: Long = defaultTimeoutMs,
+ errorMsg: String? = null,
+ test: (T) -> Boolean = { true }
+ ) = expect<CallbackEntry>(network, timeoutMs, errorMsg) {
+ // T is erased in this non-reified function, so `it as? T` would never fail and a
+ // callback of the wrong type would reach |test| (or the caller). Check the runtime
+ // class explicitly; .java avoids a dependency on kotlin-reflect.
+ if (type.java.isInstance(it)) {
+ test(it as T) // Cannot fail: type.java.isInstance(it) was just checked.
+ } else {
+ fail("Expected callback ${type.simpleName}, got $it")
+ }
+ } as T
+
+ // HasNetwork variant of expect(type, network: Network, ...): forwards network.network.
+ @JvmOverloads
+ fun <T : CallbackEntry> expect(
+ type: KClass<T>,
+ network: HasNetwork,
+ timeoutMs: Long = defaultTimeoutMs,
+ errorMsg: String? = null,
+ test: (T) -> Boolean = { true }
+ ) = expect(type, network.network, timeoutMs, errorMsg, test)
+
+ // Java needs an explicit overload to let it omit arguments in the middle, so define these
+ // here. Note that @JvmOverloads gives us the versions without the last arguments too, so
+ // there is no need to explicitly define versions without the test predicate.
+ // Each overload below forwards to expect(type, network, timeoutMs, errorMsg, test), filling
+ // an omitted network with ANY_NETWORK, an omitted timeout with defaultTimeoutMs and an
+ // omitted errorMsg with null. HasNetwork overloads forward network.network.
+ // Without |network|
+ @JvmOverloads
+ fun <T : CallbackEntry> expect(
+ type: KClass<T>,
+ timeoutMs: Long,
+ errorMsg: String?,
+ test: (T) -> Boolean = { true }
+ ) = expect(type, ANY_NETWORK, timeoutMs, errorMsg, test)
+
+ // Without |timeout|, in Network and HasNetwork versions
+ @JvmOverloads
+ fun <T : CallbackEntry> expect(
+ type: KClass<T>,
+ network: Network,
+ errorMsg: String?,
+ test: (T) -> Boolean = { true }
+ ) = expect(type, network, defaultTimeoutMs, errorMsg, test)
+
+ @JvmOverloads
+ fun <T : CallbackEntry> expect(
+ type: KClass<T>,
+ network: HasNetwork,
+ errorMsg: String?,
+ test: (T) -> Boolean = { true }
+ ) = expect(type, network.network, defaultTimeoutMs, errorMsg, test)
+
+ // Without |errorMsg|, in Network and HasNetwork versions
+ @JvmOverloads
+ fun <T : CallbackEntry> expect(
+ type: KClass<T>,
+ network: Network,
+ timeoutMs: Long,
+ test: (T) -> Boolean
+ ) = expect(type, network, timeoutMs, null, test)
+
+ @JvmOverloads
+ fun <T : CallbackEntry> expect(
+ type: KClass<T>,
+ network: HasNetwork,
+ timeoutMs: Long,
+ test: (T) -> Boolean
+ ) = expect(type, network.network, timeoutMs, null, test)
+
+ // Without |network| or |timeout|
+ @JvmOverloads
+ fun <T : CallbackEntry> expect(
+ type: KClass<T>,
+ errorMsg: String?,
+ test: (T) -> Boolean = { true }
+ ) = expect(type, ANY_NETWORK, defaultTimeoutMs, errorMsg, test)
+
+ // Without |network| or |errorMsg|
+ @JvmOverloads
+ fun <T : CallbackEntry> expect(
+ type: KClass<T>,
+ timeoutMs: Long,
+ test: (T) -> Boolean = { true }
+ ) = expect(type, ANY_NETWORK, timeoutMs, null, test)
+
+ // Without |timeout| or |errorMsg|, in Network and HasNetwork versions
+ @JvmOverloads
+ fun <T : CallbackEntry> expect(
+ type: KClass<T>,
+ network: Network,
+ test: (T) -> Boolean
+ ) = expect(type, network, defaultTimeoutMs, null, test)
+
+ @JvmOverloads
+ fun <T : CallbackEntry> expect(
+ type: KClass<T>,
+ network: HasNetwork,
+ test: (T) -> Boolean
+ ) = expect(type, network.network, defaultTimeoutMs, null, test)
+
+ // Without |network| or |timeout| or |errorMsg|
+ @JvmOverloads
+ fun <T : CallbackEntry> expect(
+ type: KClass<T>,
+ test: (T) -> Boolean
+ ) = expect(type, ANY_NETWORK, defaultTimeoutMs, null, test)
+
+ // Kotlin reified versions. Don't call methods above, or the predicate would need to be noinline
+ // Takes the next callback via pollOrThrow (so a timeout fails), then fails unless:
+ // - it is a T;
+ // - its network equals |network|, a check skipped when |network| is ANY_NETWORK (by identity);
+ // - |test| accepts it. On rejection, |errorMsg| (if non-null) replaces the default message.
+ // Returns the callback cast to T.
+ inline fun <reified T : CallbackEntry> expect(
+ network: Network = ANY_NETWORK,
+ timeoutMs: Long = defaultTimeoutMs,
+ errorMsg: String? = null,
+ test: (T) -> Boolean = { true }
+ ) = pollOrThrow(timeoutMs).also {
+ if (it !is T) fail("Expected callback ${T::class.simpleName}, got $it")
+ if (ANY_NETWORK !== network && it.network != network) {
+ fail("Expected network $network for callback : $it")
+ }
+ if (!test(it)) {
+ fail("${errorMsg ?: "Callback doesn't match predicate"} : $it")
+ }
+ } as T
+
+ // HasNetwork variant of the reified expect: forwards network.network with the other arguments.
+ inline fun <reified T : CallbackEntry> expect(
+ network: HasNetwork,
+ timeoutMs: Long = defaultTimeoutMs,
+ errorMsg: String? = null,
+ test: (T) -> Boolean = { true }
+ ) = expect(network.network, timeoutMs, errorMsg, test)
// Make open for use in ConnectivityServiceTest which is the only one knowing its handlers.
// TODO : remove the necessity to overload this, remove the open qualifier, and give a
// default argument to assertNoCallback instead, possibly with @JvmOverloads if necessary.
- open fun assertNoCallback() = assertNoCallback(defaultTimeoutMs)
+ open fun assertNoCallback() = assertNoCallback(defaultNoCallbackTimeoutMs)
fun assertNoCallback(timeoutMs: Long) {
val cb = history.poll(timeoutMs)
@@ -202,7 +354,7 @@
}
fun assertNoCallbackThat(
- timeoutMs: Long = defaultTimeoutMs,
+ timeoutMs: Long = defaultNoCallbackTimeoutMs,
valid: (CallbackEntry) -> Boolean
) {
val cb = history.poll(timeoutMs) { valid(it) }.let {
@@ -210,19 +362,6 @@
}
}
- // Expects a callback of the specified type on the specified network within the timeout.
- // If no callback arrives, or a different callback arrives, fail. Returns the callback.
- inline fun <reified T : CallbackEntry> expectCallback(
- network: Network = ANY_NETWORK,
- timeoutMs: Long = defaultTimeoutMs
- ): T = pollForNextCallback(timeoutMs).let {
- if (it !is T || (ANY_NETWORK !== network && it.network != network)) {
- fail("Unexpected callback : $it, expected ${T::class} with Network[$network]")
- } else {
- it
- }
- }
-
// Expects a callback of the specified type matching the predicate within the timeout.
// Any callback that doesn't match the predicate will be skipped. Fails only if
// no matching callback is received within the timeout.
@@ -237,8 +376,17 @@
fun <T : CallbackEntry> eventuallyExpect(
type: KClass<T>,
timeoutMs: Long = defaultTimeoutMs,
- predicate: (T: CallbackEntry) -> Boolean = { true }
- ) = history.poll(timeoutMs) { type.java.isInstance(it) && predicate(it) }.also {
+ predicate: (cb: T) -> Boolean = { true }
+ ) = history.poll(timeoutMs) { type.java.isInstance(it) && predicate(it as T) }.also {
+ assertNotNull(it, "Callback ${type.java} not received within ${timeoutMs}ms")
+ } as T
+
+ fun <T : CallbackEntry> eventuallyExpect(
+ type: KClass<T>,
+ timeoutMs: Long = defaultTimeoutMs,
+ from: Int = mark,
+ predicate: (cb: T) -> Boolean = { true }
+ ) = history.poll(timeoutMs, from) { type.java.isInstance(it) && predicate(it as T) }.also {
assertNotNull(it, "Callback ${type.java} not received within ${timeoutMs}ms")
} as T
@@ -249,30 +397,19 @@
crossinline predicate: (T) -> Boolean = { true }
) = history.poll(timeoutMs, from) { it is T && predicate(it) } as T?
- fun expectCallbackThat(
- timeoutMs: Long = defaultTimeoutMs,
- valid: (CallbackEntry) -> Boolean
- ) = pollForNextCallback(timeoutMs).also { assertTrue(valid(it), "Unexpected callback : $it") }
-
- fun expectCapabilitiesThat(
+ inline fun expectCapabilitiesThat(
net: Network,
tmt: Long = defaultTimeoutMs,
valid: (NetworkCapabilities) -> Boolean
- ): CapabilitiesChanged {
- return expectCallback<CapabilitiesChanged>(net, tmt).also {
- assertTrue(valid(it.caps), "Capabilities don't match expectations ${it.caps}")
- }
- }
+ ): CapabilitiesChanged =
+ expect(net, tmt, "Capabilities don't match expectations") { valid(it.caps) }
- fun expectLinkPropertiesThat(
+ inline fun expectLinkPropertiesThat(
net: Network,
tmt: Long = defaultTimeoutMs,
valid: (LinkProperties) -> Boolean
- ): LinkPropertiesChanged {
- return expectCallback<LinkPropertiesChanged>(net, tmt).also {
- assertTrue(valid(it.lp), "LinkProperties don't match expectations ${it.lp}")
- }
- }
+ ): LinkPropertiesChanged =
+ expect(net, tmt, "LinkProperties don't match expectations") { valid(it.lp) }
// Expects onAvailable and the callbacks that follow it. These are:
// - onSuspended, iff the network was suspended when the callbacks fire.
@@ -313,16 +450,16 @@
validated: Boolean?,
tmt: Long
) {
- expectCallback<Available>(net, tmt)
+ expect<Available>(net, tmt)
if (suspended) {
- expectCallback<Suspended>(net, tmt)
+ expect<Suspended>(net, tmt)
}
expectCapabilitiesThat(net, tmt) {
validated == null || validated == it.hasCapability(
NET_CAPABILITY_VALIDATED
)
}
- expectCallback<LinkPropertiesChanged>(net, tmt)
+ expect<LinkPropertiesChanged>(net, tmt)
}
// Backward compatibility for existing Java code. Use named arguments instead and remove all
@@ -333,17 +470,15 @@
tmt: Long = defaultTimeoutMs
) = expectAvailableCallbacks(net, suspended = true, validated = validated, tmt = tmt)
- fun expectBlockedStatusCallback(blocked: Boolean, net: Network, tmt: Long = defaultTimeoutMs) {
- expectCallback<BlockedStatus>(net, tmt).also {
- assertEquals(blocked, it.blocked, "Unexpected blocked status ${it.blocked}")
- }
- }
+ // Expects the next callback to be a BlockedStatus for |net| whose blocked flag equals
+ // |blocked|, and returns it. The failure message states the expected value; the reified
+ // expect() appends the actual callback.
+ fun expectBlockedStatusCallback(blocked: Boolean, net: Network, tmt: Long = defaultTimeoutMs) =
+ expect<BlockedStatus>(net, tmt, "Unexpected blocked status (expected $blocked)") {
+ it.blocked == blocked
+ }
- fun expectBlockedStatusCallback(blocked: Int, net: Network, tmt: Long = defaultTimeoutMs) {
- expectCallback<BlockedStatusInt>(net, tmt).also {
- assertEquals(blocked, it.blocked, "Unexpected blocked status ${it.blocked}")
- }
- }
+ // Expects the next callback to be a BlockedStatusInt for |net| whose blocked value equals
+ // |blocked|, and returns it. The failure message states the expected value; the reified
+ // expect() appends the actual callback.
+ fun expectBlockedStatusCallback(blocked: Int, net: Network, tmt: Long = defaultTimeoutMs) =
+ expect<BlockedStatusInt>(net, tmt, "Unexpected blocked status (expected $blocked)") {
+ it.blocked == blocked
+ }
// Expects the available callbacks (where the onCapabilitiesChanged must contain the
// VALIDATED capability), plus another onCapabilitiesChanged which is identical to the
@@ -353,7 +488,7 @@
val mark = history.mark
expectAvailableCallbacks(net, tmt = tmt)
val firstCaps = history.poll(tmt, mark) { it is CapabilitiesChanged }
- assertEquals(firstCaps, expectCallback<CapabilitiesChanged>(net, tmt))
+ assertEquals(firstCaps, expect<CapabilitiesChanged>(net, tmt))
}
// Expects the available callbacks where the onCapabilitiesChanged must not have validated,
@@ -381,26 +516,6 @@
val network: Network
}
- @JvmOverloads
- open fun <T : CallbackEntry> expectCallback(
- type: KClass<T>,
- n: Network?,
- timeoutMs: Long = defaultTimeoutMs
- ) = pollForNextCallback(timeoutMs).also {
- val network = n ?: NULL_NETWORK
- // TODO : remove this .java access if the tests ever use kotlin-reflect. At the time of
- // this writing this would be the only use of this library in the tests.
- assertTrue(type.java.isInstance(it) && (ANY_NETWORK === n || it.network == network),
- "Unexpected callback : $it, expected ${type.java} with Network[$network]")
- } as T
-
- @JvmOverloads
- open fun <T : CallbackEntry> expectCallback(
- type: KClass<T>,
- n: HasNetwork?,
- timeoutMs: Long = defaultTimeoutMs
- ) = expectCallback(type, n?.network, timeoutMs)
-
fun expectAvailableCallbacks(
n: HasNetwork,
suspended: Boolean,