Merge "Revert "[ST03] Add test dns server for integration tests""
diff --git a/staticlibs/device/com/android/net/module/util/NetworkMonitorUtils.java b/staticlibs/device/com/android/net/module/util/NetworkMonitorUtils.java
index f6cd044..5a4412f 100644
--- a/staticlibs/device/com/android/net/module/util/NetworkMonitorUtils.java
+++ b/staticlibs/device/com/android/net/module/util/NetworkMonitorUtils.java
@@ -16,6 +16,7 @@
 
 package com.android.net.module.util;
 
+import static android.net.NetworkCapabilities.NET_CAPABILITY_DUN;
 import static android.net.NetworkCapabilities.NET_CAPABILITY_INTERNET;
 import static android.net.NetworkCapabilities.NET_CAPABILITY_NOT_RESTRICTED;
 import static android.net.NetworkCapabilities.NET_CAPABILITY_NOT_VPN;
@@ -102,9 +103,12 @@
      *                                networks.
      * @param nc Network capabilities of the network to test.
      */
-    public static boolean isValidationRequired(boolean isVpnValidationRequired,
+    public static boolean isValidationRequired(boolean isDunValidationRequired,
+            boolean isVpnValidationRequired,
             @NonNull final NetworkCapabilities nc) {
-        // TODO: Consider requiring validation for DUN networks.
+        if (isDunValidationRequired && nc.hasCapability(NET_CAPABILITY_DUN)) {
+            return true;
+        }
         if (!nc.hasCapability(NET_CAPABILITY_NOT_VPN)) {
             return isVpnValidationRequired;
         }
diff --git a/staticlibs/framework/com/android/net/module/util/BitUtils.java b/staticlibs/framework/com/android/net/module/util/BitUtils.java
new file mode 100644
index 0000000..2b32e86
--- /dev/null
+++ b/staticlibs/framework/com/android/net/module/util/BitUtils.java
@@ -0,0 +1,110 @@
+/*
+ * Copyright (C) 2022 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.net.module.util;
+
+import android.annotation.NonNull;
+
+/**
+ * @hide
+ */
+public class BitUtils {
+    /**
+     * Unpacks a long value into an array of the positions of its set bits, in ascending order.
+     */
+    public static int[] unpackBits(long val) {
+        int size = Long.bitCount(val);
+        int[] result = new int[size];
+        int index = 0;
+        int bitPos = 0;
+        while (val != 0) {
+            if ((val & 1) == 1) result[index++] = bitPos;
+            val = val >>> 1;
+            bitPos++;
+        }
+        return result;
+    }
+
+    /**
+     * Packs a list of ints in the same way as packBits()
+     *
+     * Each passed int is the rank of a bit that should be set in the returned long.
+     * Example : passing (1,3) will return 0b00001010 and passing (5,6,0) will return 0b01100001
+     *
+     * @param bits bits to pack
+     * @return a long with the specified bits set.
+     */
+    public static long packBitList(int... bits) {
+        return packBits(bits);
+    }
+
+    /**
+     * Packs array of bits into a long value.
+     *
+     * Each passed int is the rank of a bit that should be set in the returned long.
+     * Example : passing [1,3] will return 0b00001010 and passing [5,6,0] will return 0b01100001
+     *
+     * @param bits bits to pack
+     * @return a long with the specified bits set.
+     */
+    public static long packBits(int[] bits) {
+        long packed = 0;
+        for (int b : bits) {
+            packed |= (1L << b);
+        }
+        return packed;
+    }
+
+    /**
+     * An interface for a function that can retrieve a name associated with an int.
+     *
+     * This is useful for bitfields like network capabilities or network score policies.
+     */
+    @FunctionalInterface
+    public interface NameOf {
+        /** Retrieve the name associated with the passed value */
+        String nameOf(int value);
+    }
+
+    /**
+     * Given a bitmask and a name fetcher, append names of all set bits to the builder
+     *
+     * This method takes all bits set in the passed bitmask, figures out the name associated
+     * with the 0-based rank of each bit with the passed name fetcher, and appends each name to
+     * the passed StringBuilder, separated by the passed separator.
+     *
+     * For example, if the bitmask is 0b0110, and the name fetcher returns "BIT_1" to "BIT_4" for
+     * numbers from 1 to 4, and the separator is "&", this method appends "BIT_1&BIT_2" to the
+     * StringBuilder.
+     */
+    public static void appendStringRepresentationOfBitMaskToStringBuilder(@NonNull StringBuilder sb,
+            long bitMask, @NonNull NameOf nameFetcher, @NonNull String separator) {
+        int bitPos = 0;
+        boolean firstElementAdded = false;
+        while (bitMask != 0) {
+            if ((bitMask & 1) != 0) {
+                if (firstElementAdded) {
+                    sb.append(separator);
+                } else {
+                    firstElementAdded = true;
+                }
+                sb.append(nameFetcher.nameOf(bitPos));
+            }
+            bitMask >>>= 1;
+            ++bitPos;
+        }
+    }
+}
diff --git a/staticlibs/framework/com/android/net/module/util/ByteUtils.java b/staticlibs/framework/com/android/net/module/util/ByteUtils.java
new file mode 100644
index 0000000..290ed46
--- /dev/null
+++ b/staticlibs/framework/com/android/net/module/util/ByteUtils.java
@@ -0,0 +1,67 @@
+/*
+ * Copyright (C) 2022 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.net.module.util;
+
+import android.annotation.NonNull;
+
+/**
+ * Byte utility functions.
+ * @hide
+ */
+public class ByteUtils {
+    /**
+     * Returns the index of the first appearance of the value {@code target} in {@code array}.
+     *
+     * @param array an array of {@code byte} values, possibly empty
+     * @param target a primitive {@code byte} value
+     * @return the least index {@code i} for which {@code array[i] == target}, or {@code -1} if no
+     *     such index exists.
+     */
+    public static int indexOf(@NonNull byte[] array, byte target) {
+        return indexOf(array, target, 0, array.length);
+    }
+
+    private static int indexOf(byte[] array, byte target, int start, int end) {
+        for (int i = start; i < end; i++) {
+            if (array[i] == target) {
+                return i;
+            }
+        }
+        return -1;
+    }
+
+    /**
+     * Returns the values from each provided array combined into a single array. For example,
+     * {@code concat(new byte[] {a, b}, new byte[] {}, new byte[] {c})} returns {@code {a, b, c}}.
+     *
+     * @param arrays zero or more {@code byte} arrays
+     * @return a single array containing all the values from the source arrays, in order
+     */
+    public static byte[] concat(@NonNull byte[]... arrays) {
+        int length = 0;
+        for (byte[] array : arrays) {
+            length += array.length;
+        }
+        byte[] result = new byte[length];
+        int pos = 0;
+        for (byte[] array : arrays) {
+            System.arraycopy(array, 0, result, pos, array.length);
+            pos += array.length;
+        }
+        return result;
+    }
+}
diff --git a/staticlibs/framework/com/android/net/module/util/CollectionUtils.java b/staticlibs/framework/com/android/net/module/util/CollectionUtils.java
index 7cac90d..f08880c 100644
--- a/staticlibs/framework/com/android/net/module/util/CollectionUtils.java
+++ b/staticlibs/framework/com/android/net/module/util/CollectionUtils.java
@@ -18,12 +18,15 @@
 
 import android.annotation.NonNull;
 import android.annotation.Nullable;
+import android.util.ArrayMap;
+import android.util.Pair;
 import android.util.SparseArray;
 
 import java.util.ArrayList;
 import java.util.Collection;
 import java.util.List;
 import java.util.Objects;
+import java.util.function.Function;
 import java.util.function.Predicate;
 
 /**
@@ -291,4 +294,100 @@
             @NonNull final Predicate<? super T> condition) {
         return -1 != indexOf(haystack, condition);
     }
+
+    /**
+     * Standard map function, but returns a new modifiable ArrayList
+     *
+     * This returns a new list that contains, for each element of the source collection, its
+     * image through the passed transform.
+     * Elements in the source can be null if the transform accepts null inputs.
+     * Elements in the output can be null if the transform ever returns null.
+     * This function never returns null. If the source collection is empty, it returns a new
+     * empty list.
+     * Contract : this method calls the transform function exactly once for each element in the
+     * source collection, in iteration order.
+     *
+     * @param source the source collection
+     * @param transform the function to transform the elements
+     * @param <T> type of source elements
+     * @param <R> type of destination elements
+     * @return a new modifiable list of transformed elements
+     */
+    @NonNull
+    public static <T, R> ArrayList<R> map(@NonNull final Collection<T> source,
+            @NonNull final Function<? super T, ? extends R> transform) {
+        final ArrayList<R> dest = new ArrayList<>(source.size());
+        for (final T e : source) {
+            dest.add(transform.apply(e));
+        }
+        return dest;
+    }
+
+    /**
+     * Standard zip function, but returns a new modifiable ArrayList
+     *
+     * This returns a list of pairs containing, at each position, a pair of the element from the
+     * first list at that index and the element from the second list at that index.
+     * Both lists must be the same size (IllegalArgumentException otherwise) and may contain null.
+     *
+     * The easiest way to visualize what's happening is to think of two lists being laid out next
+     * to each other and stitched together with a zipper.
+     *
+     * Contract : this method will read each element of each list exactly once, in some unspecified
+     * order. If it throws, it will not read any element.
+     *
+     * @param first the first list of elements
+     * @param second the second list of elements
+     * @param <T> the type of first elements
+     * @param <R> the type of second elements
+     * @return the zipped list
+     */
+    @NonNull
+    public static <T, R> ArrayList<Pair<T, R>> zip(@NonNull final List<T> first,
+            @NonNull final List<R> second) {
+        final int size = first.size();
+        if (size != second.size()) {
+            throw new IllegalArgumentException("zip : collections must be the same size");
+        }
+        final ArrayList<Pair<T, R>> dest = new ArrayList<>(size);
+        for (int i = 0; i < size; ++i) {
+            dest.add(new Pair<>(first.get(i), second.get(i)));
+        }
+        return dest;
+    }
+
+    /**
+     * Returns a new ArrayMap that associates each key with the value at the same index.
+     *
+     * Both lists must be the same size, otherwise IllegalArgumentException is thrown.
+     * Both keys and values may contain null.
+     * Keys may not contain the same value twice; duplicates throw IllegalArgumentException.
+     *
+     * Contract : this method will read each element of each list exactly once, but does not
+     * specify the order, except if it throws in which case the number of reads is undefined.
+     *
+     * @param keys The list of keys
+     * @param values The list of values
+     * @param <T> The type of keys
+     * @param <R> The type of values
+     * @return The associated map
+     */
+    @NonNull
+    public static <T, R> ArrayMap<T, R> assoc(
+            @NonNull final List<T> keys, @NonNull final List<R> values) {
+        final int size = keys.size();
+        if (size != values.size()) {
+            throw new IllegalArgumentException("assoc : collections must be the same size");
+        }
+        final ArrayMap<T, R> dest = new ArrayMap<>(size);
+        for (int i = 0; i < size; ++i) {
+            final T key = keys.get(i);
+            if (dest.containsKey(key)) {
+                throw new IllegalArgumentException(
+                        "assoc : keys may not contain the same value twice");
+            }
+            dest.put(key, values.get(i));
+        }
+        return dest;
+    }
 }
diff --git a/staticlibs/framework/com/android/net/module/util/NetworkCapabilitiesUtils.java b/staticlibs/framework/com/android/net/module/util/NetworkCapabilitiesUtils.java
index 26c24f8..54ce01e 100644
--- a/staticlibs/framework/com/android/net/module/util/NetworkCapabilitiesUtils.java
+++ b/staticlibs/framework/com/android/net/module/util/NetworkCapabilitiesUtils.java
@@ -44,6 +44,9 @@
 import static android.net.NetworkCapabilities.TRANSPORT_WIFI;
 import static android.net.NetworkCapabilities.TRANSPORT_WIFI_AWARE;
 
+import static com.android.net.module.util.BitUtils.packBitList;
+import static com.android.net.module.util.BitUtils.unpackBits;
+
 import android.annotation.NonNull;
 import android.net.NetworkCapabilities;
 
@@ -181,49 +184,4 @@
         return false;
     }
 
-    /**
-     * Unpacks long value into an array of bits.
-     */
-    public static int[] unpackBits(long val) {
-        int size = Long.bitCount(val);
-        int[] result = new int[size];
-        int index = 0;
-        int bitPos = 0;
-        while (val != 0) {
-            if ((val & 1) == 1) result[index++] = bitPos;
-            val = val >>> 1;
-            bitPos++;
-        }
-        return result;
-    }
-
-    /**
-     * Packs a list of ints in the same way as packBits()
-     *
-     * Each passed int is the rank of a bit that should be set in the returned long.
-     * Example : passing (1,3) will return in 0b00001010 and passing (5,6,0) will return 0b01100001
-     *
-     * @param bits bits to pack
-     * @return a long with the specified bits set.
-     */
-    public static long packBitList(int... bits) {
-        return packBits(bits);
-    }
-
-    /**
-     * Packs array of bits into a long value.
-     *
-     * Each passed int is the rank of a bit that should be set in the returned long.
-     * Example : passing [1,3] will return in 0b00001010 and passing [5,6,0] will return 0b01100001
-     *
-     * @param bits bits to pack
-     * @return a long with the specified bits set.
-     */
-    public static long packBits(int[] bits) {
-        long packed = 0;
-        for (int b : bits) {
-            packed |= (1L << b);
-        }
-        return packed;
-    }
 }
diff --git a/staticlibs/native/bpf_headers/include/bpf/BpfUtils.h b/staticlibs/native/bpf_headers/include/bpf/BpfUtils.h
index 4429164..157f210 100644
--- a/staticlibs/native/bpf_headers/include/bpf/BpfUtils.h
+++ b/staticlibs/native/bpf_headers/include/bpf/BpfUtils.h
@@ -115,23 +115,16 @@
     return kernelVersion() >= KVER(major, minor, sub);
 }
 
-#define SKIP_IF_BPF_SUPPORTED                              \
-    do {                                                   \
-        if (android::bpf::isAtLeastKernelVersion(4, 9, 0)) \
-            GTEST_SKIP() << "Skip: bpf is supported.";     \
-    } while (0)
-
 #define SKIP_IF_BPF_NOT_SUPPORTED                           \
     do {                                                    \
         if (!android::bpf::isAtLeastKernelVersion(4, 9, 0)) \
             GTEST_SKIP() << "Skip: bpf is not supported.";  \
     } while (0)
 
-#define SKIP_IF_EXTENDED_BPF_NOT_SUPPORTED                               \
-    do {                                                                 \
-        if (!android::bpf::isAtLeastKernelVersion(4, 14, 0))             \
-            GTEST_SKIP() << "Skip: extended bpf feature not supported."; \
-    } while (0)
+// Only used by tm-mainline-prod's system/netd/tests/bpf_base_test.cpp
+// but platform and platform tests aren't expected to build/work in tm-mainline-prod
+// so we can just trivialize this
+#define SKIP_IF_EXTENDED_BPF_NOT_SUPPORTED
 
 #define SKIP_IF_XDP_NOT_SUPPORTED                           \
     do {                                                    \
diff --git a/staticlibs/native/bpf_headers/include/bpf/bpf_helpers.h b/staticlibs/native/bpf_headers/include/bpf/bpf_helpers.h
index 35d9f31..c652c76 100644
--- a/staticlibs/native/bpf_headers/include/bpf/bpf_helpers.h
+++ b/staticlibs/native/bpf_headers/include/bpf/bpf_helpers.h
@@ -147,6 +147,23 @@
         __attribute__ ((section(".maps." #name), used)) \
                 ____btf_map_##name = { }
 
+/* There exist buggy kernels on pre-T OS devices that, due to
+ * kernel patch "[ALPS05162612] bpf: fix ubsan error",
+ * do not support userspace writes to a non-zero index of a bpf map array.
+ *
+ * We use this assert to reject, at build time, writable arrays with more than 1 entry.
+ */
+
+#ifdef THIS_BPF_PROGRAM_IS_FOR_TEST_PURPOSES_ONLY
+#define BPF_MAP_ASSERT_OK(type, entries, mode)
+#elif BPFLOADER_MIN_VER >= BPFLOADER_T_BETA3_VERSION
+#define BPF_MAP_ASSERT_OK(type, entries, mode)
+#else
+#define BPF_MAP_ASSERT_OK(type, entries, mode) \
+  _Static_assert(((type) != BPF_MAP_TYPE_ARRAY) || ((entries) <= 1) || !((mode) & 0222), \
+  "Writable arrays with more than 1 element not supported on pre-T devices.")
+#endif
+
 /* type safe macro to declare a map and related accessor functions */
 #define DEFINE_BPF_MAP_EXT(the_map, TYPE, KeyType, ValueType, num_entries, usr, grp, md,         \
                            selinux, pindir, share)                                               \
@@ -167,6 +184,7 @@
             .pin_subdir = pindir,                                                                \
             .shared = share,                                                                     \
     };                                                                                           \
+    BPF_MAP_ASSERT_OK(BPF_MAP_TYPE_##TYPE, (num_entries), (md));                                 \
     BPF_ANNOTATE_KV_PAIR(the_map, KeyType, ValueType);                                           \
                                                                                                  \
     static inline __always_inline __unused ValueType* bpf_##the_map##_lookup_elem(               \
@@ -205,6 +223,10 @@
     DEFINE_BPF_MAP_UGM(the_map, TYPE, KeyType, ValueType, num_entries, \
                        DEFAULT_BPF_MAP_UID, AID_ROOT, 0600)
 
+#define DEFINE_BPF_MAP_RO(the_map, TYPE, KeyType, ValueType, num_entries, gid) \
+    DEFINE_BPF_MAP_UGM(the_map, TYPE, KeyType, ValueType, num_entries, \
+                       DEFAULT_BPF_MAP_UID, gid, 0440)
+
 #define DEFINE_BPF_MAP_GWO(the_map, TYPE, KeyType, ValueType, num_entries, gid) \
     DEFINE_BPF_MAP_UGM(the_map, TYPE, KeyType, ValueType, num_entries, \
                        DEFAULT_BPF_MAP_UID, gid, 0620)
diff --git a/staticlibs/tests/unit/src/com/android/net/module/util/BitUtilsTests.kt b/staticlibs/tests/unit/src/com/android/net/module/util/BitUtilsTests.kt
new file mode 100644
index 0000000..0236716
--- /dev/null
+++ b/staticlibs/tests/unit/src/com/android/net/module/util/BitUtilsTests.kt
@@ -0,0 +1,61 @@
+/*
+ * Copyright (C) 2022 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.net.module.util
+
+import com.android.net.module.util.BitUtils.appendStringRepresentationOfBitMaskToStringBuilder
+import com.android.net.module.util.BitUtils.packBits
+import com.android.net.module.util.BitUtils.unpackBits
+import org.junit.Test
+import kotlin.test.assertEquals
+import kotlin.test.assertTrue
+
+class BitUtilsTests {
+    @Test
+    fun testBitPackingTestCase() {
+        runBitPackingTestCase(0, intArrayOf())
+        runBitPackingTestCase(1, intArrayOf(0))
+        runBitPackingTestCase(3, intArrayOf(0, 1))
+        runBitPackingTestCase(4, intArrayOf(2))
+        runBitPackingTestCase(63, intArrayOf(0, 1, 2, 3, 4, 5))
+        runBitPackingTestCase(Long.MAX_VALUE.inv(), intArrayOf(63))
+        runBitPackingTestCase(Long.MAX_VALUE.inv() + 1, intArrayOf(0, 63))
+        runBitPackingTestCase(Long.MAX_VALUE.inv() + 2, intArrayOf(1, 63))
+    }
+
+    fun runBitPackingTestCase(packedBits: Long, bits: IntArray) {
+        assertEquals(packedBits, packBits(bits))
+        assertTrue(bits contentEquals unpackBits(packedBits))
+    }
+
+    @Test
+    fun testAppendStringRepresentationOfBitMaskToStringBuilder() {
+        runTestAppendStringRepresentationOfBitMaskToStringBuilder("", 0)
+        runTestAppendStringRepresentationOfBitMaskToStringBuilder("BIT0", 0b1)
+        runTestAppendStringRepresentationOfBitMaskToStringBuilder("BIT1&BIT2&BIT4", 0b10110)
+        runTestAppendStringRepresentationOfBitMaskToStringBuilder(
+                "BIT0&BIT60&BIT61&BIT62&BIT63",
+                (0b11110000_00000000_00000000_00000000 shl 32) +
+                        0b00000000_00000000_00000000_00000001)
+    }
+
+    fun runTestAppendStringRepresentationOfBitMaskToStringBuilder(expected: String, bitMask: Long) {
+        StringBuilder().let {
+            appendStringRepresentationOfBitMaskToStringBuilder(it, bitMask, { i -> "BIT$i" }, "&")
+            assertEquals(expected, it.toString())
+        }
+    }
+}
diff --git a/staticlibs/tests/unit/src/com/android/net/module/util/ByteUtilsTests.kt b/staticlibs/tests/unit/src/com/android/net/module/util/ByteUtilsTests.kt
new file mode 100644
index 0000000..e58adad
--- /dev/null
+++ b/staticlibs/tests/unit/src/com/android/net/module/util/ByteUtilsTests.kt
@@ -0,0 +1,54 @@
+/*
+ * Copyright (C) 2022 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.net.module.util
+
+import com.android.net.module.util.ByteUtils.indexOf
+import com.android.net.module.util.ByteUtils.concat
+import org.junit.Test
+import kotlin.test.assertContentEquals
+import kotlin.test.assertEquals
+import kotlin.test.assertNotSame
+
+class ByteUtilsTests {
+    private val EMPTY = byteArrayOf()
+    private val ARRAY1 = byteArrayOf(1)
+    private val ARRAY234 = byteArrayOf(2, 3, 4)
+
+    @Test
+    fun testIndexOf() {
+        assertEquals(-1, indexOf(EMPTY, 1))
+        assertEquals(-1, indexOf(ARRAY1, 2))
+        assertEquals(-1, indexOf(ARRAY234, 1))
+        assertEquals(0, indexOf(byteArrayOf(-1), -1))
+        assertEquals(0, indexOf(ARRAY234, 2))
+        assertEquals(1, indexOf(ARRAY234, 3))
+        assertEquals(2, indexOf(ARRAY234, 4))
+        assertEquals(1, indexOf(byteArrayOf(2, 3, 2, 3), 3))
+    }
+
+    @Test
+    fun testConcat() {
+        assertContentEquals(EMPTY, concat())
+        assertContentEquals(EMPTY, concat(EMPTY))
+        assertContentEquals(EMPTY, concat(EMPTY, EMPTY, EMPTY))
+        assertContentEquals(ARRAY1, concat(ARRAY1))
+        assertNotSame(ARRAY1, concat(ARRAY1))
+        assertContentEquals(ARRAY1, concat(EMPTY, ARRAY1, EMPTY))
+        assertContentEquals(byteArrayOf(1, 1, 1), concat(ARRAY1, ARRAY1, ARRAY1))
+        assertContentEquals(byteArrayOf(1, 2, 3, 4), concat(ARRAY1, ARRAY234))
+    }
+}
\ No newline at end of file
diff --git a/staticlibs/tests/unit/src/com/android/net/module/util/CollectionUtilsTest.kt b/staticlibs/tests/unit/src/com/android/net/module/util/CollectionUtilsTest.kt
index 0f00d0b..9fb025b 100644
--- a/staticlibs/tests/unit/src/com/android/net/module/util/CollectionUtilsTest.kt
+++ b/staticlibs/tests/unit/src/com/android/net/module/util/CollectionUtilsTest.kt
@@ -22,6 +22,7 @@
 import org.junit.Test
 import org.junit.runner.RunWith
 import kotlin.test.assertEquals
+import kotlin.test.assertFailsWith
 import kotlin.test.assertFalse
 import kotlin.test.assertNull
 import kotlin.test.assertSame
@@ -133,4 +134,52 @@
         assertSame(CollectionUtils.findLast(listMulti) { it == "E" }, listMulti[7])
         assertNull(CollectionUtils.findLast(listMulti) { it == "F" })
     }
+
+    @Test
+    fun testMap() {
+        val listAE = listOf("A", "B", "C", "D", "E", null)
+        assertEquals(listAE.map { "-$it-" }, CollectionUtils.map(listAE) { "-$it-" })
+    }
+
+    @Test
+    fun testZip() {
+        val listAE = listOf("A", "B", "C", "D", "E")
+        val list15 = listOf(1, 2, 3, 4, 5)
+        // Normal #zip returns kotlin.Pair, not android.util.Pair
+        assertEquals(list15.zip(listAE).map { android.util.Pair(it.first, it.second) },
+                CollectionUtils.zip(list15, listAE))
+        val listNull = listOf("A", null, "B", "C", "D")
+        assertEquals(list15.zip(listNull).map { android.util.Pair(it.first, it.second) },
+                CollectionUtils.zip(list15, listNull))
+        assertEquals(emptyList<android.util.Pair<Int, Int>>(),
+                CollectionUtils.zip(emptyList<Int>(), emptyList<Int>()))
+        assertFailsWith<IllegalArgumentException> {
+            // Different size
+            CollectionUtils.zip(listOf(1, 2), list15)
+        }
+    }
+
+    @Test
+    fun testAssoc() {
+        val listADA = listOf("A", "B", "C", "D", "A")
+        val list15 = listOf(1, 2, 3, 4, 5)
+        assertEquals(list15.zip(listADA).toMap(), CollectionUtils.assoc(list15, listADA))
+
+        // Null key is fine
+        val assoc = CollectionUtils.assoc(listOf(1, 2, null), listOf("A", "B", "C"))
+        assertEquals("C", assoc[null])
+
+        assertFailsWith<IllegalArgumentException> {
+            // Same key multiple times
+            CollectionUtils.assoc(listOf("A", "B", "A"), listOf(1, 2, 3))
+        }
+        assertFailsWith<IllegalArgumentException> {
+            // Same key multiple times, but it's null
+            CollectionUtils.assoc(listOf(null, "B", null), listOf(1, 2, 3))
+        }
+        assertFailsWith<IllegalArgumentException> {
+            // Different size
+            CollectionUtils.assoc(listOf(1, 2), list15)
+        }
+    }
 }
diff --git a/staticlibs/tests/unit/src/com/android/net/module/util/NetworkCapabilitiesUtilsTest.kt b/staticlibs/tests/unit/src/com/android/net/module/util/NetworkCapabilitiesUtilsTest.kt
index 256ea1e..958f45f 100644
--- a/staticlibs/tests/unit/src/com/android/net/module/util/NetworkCapabilitiesUtilsTest.kt
+++ b/staticlibs/tests/unit/src/com/android/net/module/util/NetworkCapabilitiesUtilsTest.kt
@@ -22,7 +22,6 @@
 import android.net.NetworkCapabilities.NET_CAPABILITY_CBS
 import android.net.NetworkCapabilities.NET_CAPABILITY_EIMS
 import android.net.NetworkCapabilities.NET_CAPABILITY_INTERNET
-import android.net.NetworkCapabilities.NET_CAPABILITY_INTERNET
 import android.net.NetworkCapabilities.NET_CAPABILITY_OEM_PAID
 import android.net.NetworkCapabilities.TRANSPORT_BLUETOOTH
 import android.net.NetworkCapabilities.TRANSPORT_CELLULAR
@@ -38,8 +37,6 @@
 import com.android.net.module.util.NetworkCapabilitiesUtils.RESTRICTED_CAPABILITIES
 import com.android.net.module.util.NetworkCapabilitiesUtils.UNRESTRICTED_CAPABILITIES
 import com.android.net.module.util.NetworkCapabilitiesUtils.getDisplayTransport
-import com.android.net.module.util.NetworkCapabilitiesUtils.packBits
-import com.android.net.module.util.NetworkCapabilitiesUtils.unpackBits
 import org.junit.Test
 import org.junit.runner.RunWith
 import java.lang.IllegalArgumentException
@@ -75,23 +72,6 @@
         }
     }
 
-    @Test
-    fun testBitPackingTestCase() {
-        runBitPackingTestCase(0, intArrayOf())
-        runBitPackingTestCase(1, intArrayOf(0))
-        runBitPackingTestCase(3, intArrayOf(0, 1))
-        runBitPackingTestCase(4, intArrayOf(2))
-        runBitPackingTestCase(63, intArrayOf(0, 1, 2, 3, 4, 5))
-        runBitPackingTestCase(Long.MAX_VALUE.inv(), intArrayOf(63))
-        runBitPackingTestCase(Long.MAX_VALUE.inv() + 1, intArrayOf(0, 63))
-        runBitPackingTestCase(Long.MAX_VALUE.inv() + 2, intArrayOf(1, 63))
-    }
-
-    fun runBitPackingTestCase(packedBits: Long, bits: IntArray) {
-        assertEquals(packedBits, packBits(bits))
-        assertTrue(bits contentEquals unpackBits(packedBits))
-    }
-
     // NetworkCapabilities constructor and Builder are not available until R. Mark TargetApi to
     // ignore the linter error since it's used in only unit test.
     @Test @TargetApi(Build.VERSION_CODES.R)
diff --git a/staticlibs/tests/unit/src/com/android/net/module/util/TrackRecordTest.kt b/staticlibs/tests/unit/src/com/android/net/module/util/TrackRecordTest.kt
index 9fb4d8c..8e320d0 100644
--- a/staticlibs/tests/unit/src/com/android/net/module/util/TrackRecordTest.kt
+++ b/staticlibs/tests/unit/src/com/android/net/module/util/TrackRecordTest.kt
@@ -17,6 +17,7 @@
 package com.android.net.module.util
 
 import com.android.testutils.ConcurrentInterpreter
+import com.android.testutils.INTERPRET_TIME_UNIT
 import com.android.testutils.InterpretException
 import com.android.testutils.InterpretMatcher
 import com.android.testutils.SyntaxException
@@ -420,7 +421,7 @@
     // the test code to not compile instead of throw, but it's vastly more complex and this will
     // fail 100% at runtime any test that would not have compiled.
     Regex("""poll\((\d+)?\)\s*(\{.*\})?""") to { i, t, r ->
-        (if (r.strArg(1).isEmpty()) i.interpretTimeUnit else r.timeArg(1)).let { time ->
+        (if (r.strArg(1).isEmpty()) INTERPRET_TIME_UNIT else r.timeArg(1)).let { time ->
             (t as ArrayTrackRecord<Int>.ReadHead).poll(time, makePredicate(r.strArg(2)))
         }
     },
diff --git a/staticlibs/tests/unit/src/com/android/testutils/TestableNetworkCallbackTest.kt b/staticlibs/tests/unit/src/com/android/testutils/TestableNetworkCallbackTest.kt
new file mode 100644
index 0000000..f8f2da0
--- /dev/null
+++ b/staticlibs/tests/unit/src/com/android/testutils/TestableNetworkCallbackTest.kt
@@ -0,0 +1,477 @@
+/*
+ * Copyright (C) 2022 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.testutils
+
+import android.annotation.SuppressLint
+import android.net.LinkAddress
+import android.net.LinkProperties
+import android.net.Network
+import android.net.NetworkCapabilities
+import com.android.testutils.RecorderCallback.CallbackEntry
+import com.android.testutils.RecorderCallback.CallbackEntry.Companion.AVAILABLE
+import com.android.testutils.RecorderCallback.CallbackEntry.Companion.BLOCKED_STATUS
+import com.android.testutils.RecorderCallback.CallbackEntry.Companion.LINK_PROPERTIES_CHANGED
+import com.android.testutils.RecorderCallback.CallbackEntry.Companion.LOSING
+import com.android.testutils.RecorderCallback.CallbackEntry.Companion.NETWORK_CAPS_UPDATED
+import com.android.testutils.RecorderCallback.CallbackEntry.Companion.LOST
+import com.android.testutils.RecorderCallback.CallbackEntry.Companion.RESUMED
+import com.android.testutils.RecorderCallback.CallbackEntry.Companion.SUSPENDED
+import com.android.testutils.RecorderCallback.CallbackEntry.Companion.UNAVAILABLE
+import com.android.testutils.RecorderCallback.CallbackEntry.Available
+import com.android.testutils.RecorderCallback.CallbackEntry.BlockedStatus
+import com.android.testutils.RecorderCallback.CallbackEntry.CapabilitiesChanged
+import org.junit.Before
+import org.junit.Test
+import org.junit.runner.RunWith
+import org.junit.runners.JUnit4
+import org.junit.Assume.assumeTrue
+import kotlin.reflect.KClass
+import kotlin.test.assertEquals
+import kotlin.test.assertFails
+import kotlin.test.assertNull
+import kotlin.test.assertTrue
+import kotlin.test.fail
+
+const val SHORT_TIMEOUT_MS = 20L
+const val DEFAULT_LINGER_DELAY_MS = 30000
+const val NOT_METERED = NetworkCapabilities.NET_CAPABILITY_NOT_METERED
+const val WIFI = NetworkCapabilities.TRANSPORT_WIFI
+const val CELLULAR = NetworkCapabilities.TRANSPORT_CELLULAR
+const val TEST_INTERFACE_NAME = "testInterfaceName"
+
+@RunWith(JUnit4::class)
+@SuppressLint("NewApi") // Uses hidden APIs, which the linter would identify as missing APIs.
+class TestableNetworkCallbackTest {
+    private lateinit var mCallback: TestableNetworkCallback
+
+    private fun makeHasNetwork(netId: Int) = object : TestableNetworkCallback.HasNetwork {
+        override val network: Network = Network(netId)
+    }
+
+    @Before
+    fun setUp() {
+        mCallback = TestableNetworkCallback()
+    }
+
+    @Test
+    fun testLastAvailableNetwork() {
+        // Make sure there is no last available network at first, then the last available network
+        // is returned after onAvailable is called. assertEquals takes (expected, actual).
+        val net2097 = Network(2097)
+        assertNull(mCallback.lastAvailableNetwork)
+        mCallback.onAvailable(net2097)
+        assertEquals(net2097, mCallback.lastAvailableNetwork)
+
+        // Make sure calling onCapsChanged/onLinkPropertiesChanged doesn't affect the last
+        // available network.
+        mCallback.onCapabilitiesChanged(net2097, NetworkCapabilities())
+        mCallback.onLinkPropertiesChanged(net2097, LinkProperties())
+        assertEquals(net2097, mCallback.lastAvailableNetwork)
+
+        // Make sure onLost clears the last available network.
+        mCallback.onLost(net2097)
+        assertNull(mCallback.lastAvailableNetwork)
+
+        // Do the same but with a different network after onLost : make sure the last available
+        // network is the new one, not the original one.
+        val net2098 = Network(2098)
+        mCallback.onAvailable(net2098)
+        mCallback.onCapabilitiesChanged(net2098, NetworkCapabilities())
+        mCallback.onLinkPropertiesChanged(net2098, LinkProperties())
+        assertEquals(net2098, mCallback.lastAvailableNetwork)
+
+        // Make sure onAvailable changes the last available network even if onLost was not called.
+        val net2099 = Network(2099)
+        mCallback.onAvailable(net2099)
+        assertEquals(net2099, mCallback.lastAvailableNetwork)
+
+        // For legacy reasons, lastAvailableNetwork is null as soon as any is lost, not necessarily
+        // the last available one. Check that behavior.
+        mCallback.onLost(net2098)
+        assertNull(mCallback.lastAvailableNetwork)
+
+        // Make sure that losing the really last available one still results in null.
+        mCallback.onLost(net2099)
+        assertNull(mCallback.lastAvailableNetwork)
+
+        // Make sure multiple onAvailable in a row then onLost still results in null.
+        mCallback.onAvailable(net2097)
+        mCallback.onAvailable(net2098)
+        mCallback.onAvailable(net2099)
+        mCallback.onLost(net2097)
+        assertNull(mCallback.lastAvailableNetwork)
+    }
+
+    @Test
+    fun testAssertNoCallback() {
+        mCallback.assertNoCallback(SHORT_TIMEOUT_MS)
+        mCallback.onAvailable(Network(100))
+        assertFails { mCallback.assertNoCallback(SHORT_TIMEOUT_MS) }
+    }
+
+    @Test
+    fun testAssertNoCallbackThat() {
+        val net = Network(101)
+        mCallback.assertNoCallbackThat { it is Available }
+        mCallback.onAvailable(net)
+        // Expect no blocked status change. Receiving other callbacks does not fail the test.
+        mCallback.assertNoCallbackThat { it is BlockedStatus }
+        mCallback.onBlockedStatusChanged(net, true)
+        assertFails { mCallback.assertNoCallbackThat { it is BlockedStatus } }
+        mCallback.onBlockedStatusChanged(net, false)
+        mCallback.onCapabilitiesChanged(net, NetworkCapabilities())
+        assertFails { mCallback.assertNoCallbackThat { it is CapabilitiesChanged } }
+    }
+
+    @Test
+    fun testCapabilitiesWithAndWithout() {
+        val net = Network(101)
+        val matcher = makeHasNetwork(101)
+        val meteredNc = NetworkCapabilities()
+        val unmeteredNc = NetworkCapabilities().addCapability(NOT_METERED)
+        // Check that expecting caps (with or without) fails when no callback has been received.
+        assertFails { mCallback.expectCapabilitiesWith(NOT_METERED, matcher, SHORT_TIMEOUT_MS) }
+        assertFails { mCallback.expectCapabilitiesWithout(NOT_METERED, matcher, SHORT_TIMEOUT_MS) }
+
+        // Add NOT_METERED and check that With succeeds and Without fails.
+        mCallback.onCapabilitiesChanged(net, unmeteredNc)
+        mCallback.expectCapabilitiesWith(NOT_METERED, matcher)
+        mCallback.onCapabilitiesChanged(net, unmeteredNc)
+        assertFails { mCallback.expectCapabilitiesWithout(NOT_METERED, matcher, SHORT_TIMEOUT_MS) }
+
+        // Don't add NOT_METERED and check that With fails and Without succeeds.
+        mCallback.onCapabilitiesChanged(net, meteredNc)
+        assertFails { mCallback.expectCapabilitiesWith(NOT_METERED, matcher, SHORT_TIMEOUT_MS) }
+        mCallback.onCapabilitiesChanged(net, meteredNc)
+        mCallback.expectCapabilitiesWithout(NOT_METERED, matcher)
+    }
+
+    @Test
+    fun testExpectWithPredicate() {
+        val net = Network(193)
+        val netCaps = NetworkCapabilities().addTransportType(CELLULAR)
+        // Check that expecting any callback fails when no callback has been received.
+        assertFails { mCallback.expect<CallbackEntry>(timeoutMs = SHORT_TIMEOUT_MS) { true } }
+
+        // Basic test for predicates that always return true and always return false
+        mCallback.onAvailable(net)
+        mCallback.expect<Available> { true }
+        mCallback.onAvailable(net)
+        assertFails { mCallback.expect<CallbackEntry>(timeoutMs = SHORT_TIMEOUT_MS) { false } }
+
+        // Try a matching predicate, then a predicate that rejects the received callback
+        mCallback.onBlockedStatusChanged(net, true)
+        mCallback.expect<CallbackEntry> { cb -> cb is BlockedStatus && cb.blocked }
+        mCallback.onCapabilitiesChanged(net, netCaps)
+        assertFails { mCallback.expect<CallbackEntry>(timeoutMs = SHORT_TIMEOUT_MS) { cb ->
+            cb is CapabilitiesChanged && cb.caps.hasTransport(WIFI)
+        } }
+    }
+
+    @Test
+    fun testCapabilitiesThat() {
+        val net = Network(101)
+        val netCaps = NetworkCapabilities().addCapability(NOT_METERED).addTransportType(WIFI)
+        // Check that expecting capabilitiesThat anything fails when no callback has been received.
+        assertFails { mCallback.expectCapabilitiesThat(net, SHORT_TIMEOUT_MS) { true } }
+
+        // Basic test for true and false
+        mCallback.onCapabilitiesChanged(net, netCaps)
+        mCallback.expectCapabilitiesThat(net) { true }
+        mCallback.onCapabilitiesChanged(net, netCaps)
+        assertFails { mCallback.expectCapabilitiesThat(net, SHORT_TIMEOUT_MS) { false } }
+
+        // Try a positive and a negative case
+        mCallback.onCapabilitiesChanged(net, netCaps)
+        mCallback.expectCapabilitiesThat(net) { caps ->
+            caps.hasCapability(NOT_METERED) &&
+                    caps.hasTransport(WIFI) &&
+                    !caps.hasTransport(CELLULAR)
+        }
+        mCallback.onCapabilitiesChanged(net, netCaps)
+        assertFails { mCallback.expectCapabilitiesThat(net, SHORT_TIMEOUT_MS) { caps ->
+            caps.hasTransport(CELLULAR)
+        } }
+
+        // Try a matching callback on the wrong network
+        mCallback.onCapabilitiesChanged(net, netCaps)
+        assertFails { mCallback.expectCapabilitiesThat(Network(100), SHORT_TIMEOUT_MS) { true } }
+    }
+
+    @Test
+    fun testLinkPropertiesThat() {
+        val net = Network(112)
+        val linkAddress = LinkAddress("fe80::ace:d00d/64")
+        val mtu = 1984
+        val linkProps = LinkProperties().apply {
+            this.mtu = mtu
+            interfaceName = TEST_INTERFACE_NAME
+            addLinkAddress(linkAddress)
+        }
+
+        // Check that expecting linkPropsThat anything fails when no callback has been received.
+        assertFails { mCallback.expectLinkPropertiesThat(net, SHORT_TIMEOUT_MS) { true } }
+
+        // Basic test for true and false
+        mCallback.onLinkPropertiesChanged(net, linkProps)
+        mCallback.expectLinkPropertiesThat(net) { true }
+        mCallback.onLinkPropertiesChanged(net, linkProps)
+        assertFails { mCallback.expectLinkPropertiesThat(net, SHORT_TIMEOUT_MS) { false } }
+
+        // Try a positive and negative case
+        mCallback.onLinkPropertiesChanged(net, linkProps)
+        mCallback.expectLinkPropertiesThat(net) { lp ->
+            lp.interfaceName == TEST_INTERFACE_NAME &&
+                    lp.linkAddresses.contains(linkAddress) &&
+                    lp.mtu == mtu
+        }
+        mCallback.onLinkPropertiesChanged(net, linkProps)
+        assertFails { mCallback.expectLinkPropertiesThat(net, SHORT_TIMEOUT_MS) { lp ->
+            lp.interfaceName != TEST_INTERFACE_NAME
+        } }
+
+        // Try a matching callback on the wrong network
+        mCallback.onLinkPropertiesChanged(net, linkProps)
+        assertFails { mCallback.expectLinkPropertiesThat(Network(114), SHORT_TIMEOUT_MS) { lp ->
+            lp.interfaceName == TEST_INTERFACE_NAME
+        } }
+    }
+
+    @Test
+    fun testExpect() {
+        val net = Network(103)
+        // Test that expect fails when no callback was sent.
+        assertFails { mCallback.expect<BlockedStatus>(net, SHORT_TIMEOUT_MS) }
+
+        // Test onAvailable is seen and can be expected
+        mCallback.onAvailable(net)
+        mCallback.expect<Available>(net, SHORT_TIMEOUT_MS)
+
+        // Test that expect does not match an onAvailable call for a different network
+        mCallback.onAvailable(Network(106))
+        assertFails { mCallback.expect<Available>(net, SHORT_TIMEOUT_MS) }
+
+        // Test that expecting a different callback type does not match an onAvailable call
+        mCallback.onAvailable(net)
+        assertFails { mCallback.expect<BlockedStatus>(net, SHORT_TIMEOUT_MS) }
+    }
+
+    @Test
+    fun testAllExpectOverloads() {
+        // This test should never run, it only checks that all overloads exist and build
+        assumeTrue(false)
+        val hn = object : TestableNetworkCallback.HasNetwork { override val network = ANY_NETWORK }
+
+        // Method with all arguments (version that takes a Network)
+        mCallback.expect(AVAILABLE, ANY_NETWORK, 10, "error") { true }
+
+        // Java overloads omitting one argument. One line for omitting each argument, in positional
+        // order. Versions that take a Network.
+        mCallback.expect(AVAILABLE, 10, "error") { true }
+        mCallback.expect(AVAILABLE, ANY_NETWORK, "error") { true }
+        mCallback.expect(AVAILABLE, ANY_NETWORK, 10) { true }
+        mCallback.expect(AVAILABLE, ANY_NETWORK, 10, "error")
+
+        // Java overloads for omitting two arguments. One line for omitting each pair of arguments.
+        // Versions that take a Network.
+        mCallback.expect(AVAILABLE, "error") { true }
+        mCallback.expect(AVAILABLE, 10) { true }
+        mCallback.expect(AVAILABLE, 10, "error")
+        mCallback.expect(AVAILABLE, ANY_NETWORK) { true }
+        mCallback.expect(AVAILABLE, ANY_NETWORK, "error")
+        mCallback.expect(AVAILABLE, ANY_NETWORK, 10)
+
+        // Java overloads for omitting three arguments. One line for each remaining argument.
+        // Versions that take a Network.
+        mCallback.expect(AVAILABLE) { true }
+        mCallback.expect(AVAILABLE, "error")
+        mCallback.expect(AVAILABLE, 10)
+        mCallback.expect(AVAILABLE, ANY_NETWORK)
+
+        // Java overload for omitting all four arguments.
+        mCallback.expect(AVAILABLE)
+
+        // Same orders as above, but versions that take a HasNetwork. Except overloads that
+        // were already tested because they omitted the Network argument
+        mCallback.expect(AVAILABLE, hn, 10, "error") { true }
+        mCallback.expect(AVAILABLE, hn, "error") { true }
+        mCallback.expect(AVAILABLE, hn, 10) { true }
+        mCallback.expect(AVAILABLE, hn, 10, "error")
+
+        mCallback.expect(AVAILABLE, hn) { true }
+        mCallback.expect(AVAILABLE, hn, "error")
+        mCallback.expect(AVAILABLE, hn, 10)
+
+        mCallback.expect(AVAILABLE, hn)
+
+        // Same as above but for reified versions.
+        mCallback.expect<Available>(ANY_NETWORK, 10, "error") { true }
+        mCallback.expect<Available>(timeoutMs = 10, errorMsg = "error") { true }
+        mCallback.expect<Available>(network = ANY_NETWORK, errorMsg = "error") { true }
+        mCallback.expect<Available>(network = ANY_NETWORK, timeoutMs = 10) { true }
+        mCallback.expect<Available>(network = ANY_NETWORK, timeoutMs = 10, errorMsg = "error")
+
+        mCallback.expect<Available>(errorMsg = "error") { true }
+        mCallback.expect<Available>(timeoutMs = 10) { true }
+        mCallback.expect<Available>(timeoutMs = 10, errorMsg = "error")
+        mCallback.expect<Available>(network = ANY_NETWORK) { true }
+        mCallback.expect<Available>(network = ANY_NETWORK, errorMsg = "error")
+        mCallback.expect<Available>(network = ANY_NETWORK, timeoutMs = 10)
+
+        mCallback.expect<Available> { true }
+        mCallback.expect<Available>(errorMsg = "error")
+        mCallback.expect<Available>(timeoutMs = 10)
+        mCallback.expect<Available>(network = ANY_NETWORK)
+        mCallback.expect<Available>()
+
+        mCallback.expect<Available>(hn, 10, "error") { true }
+        mCallback.expect<Available>(network = hn, errorMsg = "error") { true }
+        mCallback.expect<Available>(network = hn, timeoutMs = 10) { true }
+        mCallback.expect<Available>(network = hn, timeoutMs = 10, errorMsg = "error")
+
+        mCallback.expect<Available>(network = hn) { true }
+        mCallback.expect<Available>(network = hn, errorMsg = "error")
+        mCallback.expect<Available>(network = hn, timeoutMs = 10)
+
+        mCallback.expect<Available>(network = hn)
+    }
+
+    @Test
+    fun testPoll() {
+        assertNull(mCallback.poll(SHORT_TIMEOUT_MS))
+        TNCInterpreter.interpretTestSpec(initial = mCallback, lineShift = 1,
+                threadTransform = { cb -> cb.createLinkedCopy() }, spec = """
+            sleep; onAvailable(133)    | poll(2) = Available(133) time 1..4
+                                       | poll(1) = null
+            onCapabilitiesChanged(108) | poll(1) = CapabilitiesChanged(108) time 0..3
+            onBlockedStatus(199)       | poll(1) = BlockedStatus(199) time 0..3
+        """)
+    }
+
+    @Test
+    fun testPollOrThrow() {
+        assertFails { mCallback.pollOrThrow(SHORT_TIMEOUT_MS) }
+        TNCInterpreter.interpretTestSpec(initial = mCallback, lineShift = 1,
+                threadTransform = { cb -> cb.createLinkedCopy() }, spec = """
+            sleep; onAvailable(133)    | pollOrThrow(2) = Available(133) time 1..4
+                                       | pollOrThrow(1) fails
+            onCapabilitiesChanged(108) | pollOrThrow(1) = CapabilitiesChanged(108) time 0..3
+            onBlockedStatus(199)       | pollOrThrow(1) = BlockedStatus(199) time 0..3
+        """)
+    }
+
+    @Test
+    fun testEventuallyExpect() {
+        // TODO: Current test does not verify the inline one. Also verify the behavior after
+        // aligning two eventuallyExpect()
+        val net1 = Network(100)
+        val net2 = Network(101)
+        mCallback.onAvailable(net1)
+        mCallback.onCapabilitiesChanged(net1, NetworkCapabilities())
+        mCallback.onLinkPropertiesChanged(net1, LinkProperties())
+        mCallback.eventuallyExpect(LINK_PROPERTIES_CHANGED) {
+            net1.equals(it.network)
+        }
+        // No further new callback. Expect no callback.
+        assertFails { mCallback.eventuallyExpect(LINK_PROPERTIES_CHANGED, SHORT_TIMEOUT_MS) }
+
+        // Verify no predicate set.
+        mCallback.onAvailable(net2)
+        mCallback.onLinkPropertiesChanged(net2, LinkProperties())
+        mCallback.onBlockedStatusChanged(net1, false)
+        mCallback.eventuallyExpect(BLOCKED_STATUS) { net1.equals(it.network) }
+        // Verify no callback received if the callback does not happen.
+        assertFails { mCallback.eventuallyExpect(LOSING, SHORT_TIMEOUT_MS) }
+    }
+
+    @Test
+    fun testEventuallyExpectOnMultiThreads() {
+        TNCInterpreter.interpretTestSpec(initial = mCallback, lineShift = 1,
+                threadTransform = { cb -> cb.createLinkedCopy() }, spec = """
+                onAvailable(100)                   | eventually(CapabilitiesChanged(100), 1) fails
+                sleep ; onCapabilitiesChanged(100) | eventually(CapabilitiesChanged(100), 3)
+                onAvailable(101) ; onBlockedStatus(101) | eventually(BlockedStatus(100), 2) fails
+                onSuspended(100) ; sleep ; onLost(100)  | eventually(Lost(100), 3)
+        """)
+    }
+}
+
+private object TNCInterpreter : ConcurrentInterpreter<TestableNetworkCallback>(interpretTable)
+
+val EntryList = CallbackEntry::class.sealedSubclasses.map { it.simpleName }.joinToString("|")
+private fun callbackEntryFromString(name: String): KClass<out CallbackEntry> =
+    // Look up the sealed subclass by simple name; first() throws if none matches.
+    CallbackEntry::class.sealedSubclasses.first { entryClass -> entryClass.simpleName == name }
+
+@SuppressLint("NewApi") // Uses hidden APIs, which the linter would identify as missing APIs.
+private val interpretTable = listOf<InterpretMatcher<TestableNetworkCallback>>(
+    // Interpret "<expr> = Available(xx)" as "evaluate <expr> and check that the result is an
+    // Available entry whose netId is xx", and likewise for all callback types. The type names
+    // come from enumerating the subclasses of CallbackEntry and reading their simpleName.
+    Regex("""(.*)\s+=\s+($EntryList)\((\d+)\)""") to { i, cb, t ->
+        val record = i.interpret(t.strArg(1), cb)
+        assertTrue(callbackEntryFromString(t.strArg(2)).isInstance(record))
+        // Strictly speaking testing for is CallbackEntry is useless as it's been tested above
+        // but the compiler can't figure things out from the isInstance call. It does understand
+        // from the assertTrue(is CallbackEntry) that this is true, which allows to access
+        // the 'network' member below.
+        assertTrue(record is CallbackEntry)
+        assertEquals(t.intArg(3), record.network.netId)
+    },
+    // Interpret "onAvailable(xx)" as calling "onAvailable" with a netId of xx, and likewise for
+    // all callback types. NetworkCapabilities and LinkProperties just get an empty object
+    // as their argument. Losing gets the default linger timer. Blocked gets false.
+    Regex("""on($EntryList)\((\d+)\)""") to { _, cb, t ->
+        val net = Network(t.intArg(2))
+        when (t.strArg(1)) {
+            "Available" -> cb.onAvailable(net)
+            // PreCheck not used in tests. Add it here if it becomes useful.
+            "CapabilitiesChanged" -> cb.onCapabilitiesChanged(net, NetworkCapabilities())
+            "LinkPropertiesChanged" -> cb.onLinkPropertiesChanged(net, LinkProperties())
+            "Suspended" -> cb.onNetworkSuspended(net)
+            "Resumed" -> cb.onNetworkResumed(net)
+            "Losing" -> cb.onLosing(net, DEFAULT_LINGER_DELAY_MS)
+            "Lost" -> cb.onLost(net)
+            "Unavailable" -> cb.onUnavailable()
+            "BlockedStatus" -> cb.onBlockedStatusChanged(net, false)
+            else -> fail("Unknown callback type")
+        }
+    },
+    Regex("""poll\((\d+)\)""") to { _, cb, t -> cb.poll(t.timeArg(1)) },
+    Regex("""pollOrThrow\((\d+)\)""") to { _, cb, t -> cb.pollOrThrow(t.timeArg(1)) },
+    // Interpret "eventually(Available(xx), timeout)" as calling eventuallyExpect that expects
+    // CallbackEntry.AVAILABLE with netId of xx within timeout*INTERPRET_TIME_UNIT ms, and
+    // likewise for all callback types.
+    Regex("""eventually\(($EntryList)\((\d+)\),\s+(\d+)\)""") to { _, cb, t ->
+        val net = Network(t.intArg(2))
+        val timeout = t.timeArg(3)
+        when (t.strArg(1)) {
+            "Available" -> cb.eventuallyExpect(AVAILABLE, timeout) { net == it.network }
+            "Suspended" -> cb.eventuallyExpect(SUSPENDED, timeout) { net == it.network }
+            "Resumed" -> cb.eventuallyExpect(RESUMED, timeout) { net == it.network }
+            "Losing" -> cb.eventuallyExpect(LOSING, timeout) { net == it.network }
+            "Lost" -> cb.eventuallyExpect(LOST, timeout) { net == it.network }
+            "Unavailable" -> cb.eventuallyExpect(UNAVAILABLE, timeout) { net == it.network }
+            "BlockedStatus" -> cb.eventuallyExpect(BLOCKED_STATUS, timeout) { net == it.network }
+            "CapabilitiesChanged" ->
+                cb.eventuallyExpect(NETWORK_CAPS_UPDATED, timeout) { net == it.network }
+            "LinkPropertiesChanged" ->
+                cb.eventuallyExpect(LINK_PROPERTIES_CHANGED, timeout) { net == it.network }
+            else -> fail("Unknown callback type")
+        }
+    }
+)
diff --git a/staticlibs/tests/unit/src/com/android/testutils/TestableNetworkCallbackTestJava.java b/staticlibs/tests/unit/src/com/android/testutils/TestableNetworkCallbackTestJava.java
new file mode 100644
index 0000000..4570d0a
--- /dev/null
+++ b/staticlibs/tests/unit/src/com/android/testutils/TestableNetworkCallbackTestJava.java
@@ -0,0 +1,76 @@
+/*
+ * Copyright (C) 2022 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.testutils;
+
+import static com.android.testutils.RecorderCallback.CallbackEntry.AVAILABLE;
+import static com.android.testutils.TestableNetworkCallbackKt.anyNetwork;
+
+import static org.junit.Assume.assumeTrue;
+
+import org.junit.Test;
+
+public class TestableNetworkCallbackTestJava {
+    @Test
+    public void testAllExpectOverloads() {
+        // Never runs past the assumption; it only checks that all overloads exist and build.
+        assumeTrue(false);
+        final TestableNetworkCallback callback = new TestableNetworkCallback();
+        TestableNetworkCallback.HasNetwork hn = TestableNetworkCallbackKt::anyNetwork;
+
+        // Method with all arguments (version that takes a Network)
+        callback.expect(AVAILABLE, anyNetwork(), 10, "error", cb -> true);
+
+        // Overloads omitting one argument. One line for omitting each argument, in positional
+        // order. Versions that take a Network.
+        callback.expect(AVAILABLE, 10, "error", cb -> true);
+        callback.expect(AVAILABLE, anyNetwork(), "error", cb -> true);
+        callback.expect(AVAILABLE, anyNetwork(), 10, cb -> true);
+        callback.expect(AVAILABLE, anyNetwork(), 10, "error");
+
+        // Overloads for omitting two arguments. One line for omitting each pair of arguments.
+        // Versions that take a Network.
+        callback.expect(AVAILABLE, "error", cb -> true);
+        callback.expect(AVAILABLE, 10, cb -> true);
+        callback.expect(AVAILABLE, 10, "error");
+        callback.expect(AVAILABLE, anyNetwork(), cb -> true);
+        callback.expect(AVAILABLE, anyNetwork(), "error");
+        callback.expect(AVAILABLE, anyNetwork(), 10);
+
+        // Overloads for omitting three arguments. One line for each remaining argument.
+        // Versions that take a Network.
+        callback.expect(AVAILABLE, cb -> true);
+        callback.expect(AVAILABLE, "error");
+        callback.expect(AVAILABLE, 10);
+        callback.expect(AVAILABLE, anyNetwork());
+
+        // Java overload for omitting all four arguments.
+        callback.expect(AVAILABLE);
+
+        // Same orders as above, but versions that take a HasNetwork. Except overloads that
+        // were already tested because they omitted the Network argument
+        callback.expect(AVAILABLE, hn, 10, "error", cb -> true);
+        callback.expect(AVAILABLE, hn, "error", cb -> true);
+        callback.expect(AVAILABLE, hn, 10, cb -> true);
+        callback.expect(AVAILABLE, hn, 10, "error");
+
+        callback.expect(AVAILABLE, hn, cb -> true);
+        callback.expect(AVAILABLE, hn, "error");
+        callback.expect(AVAILABLE, hn, 10);
+
+        callback.expect(AVAILABLE, hn);
+    }
+}
diff --git a/staticlibs/testutils/devicetests/com/android/testutils/ConcurrentInterpreter.kt b/staticlibs/testutils/devicetests/com/android/testutils/ConcurrentInterpreter.kt
index cbdc017..9e72f4b 100644
--- a/staticlibs/testutils/devicetests/com/android/testutils/ConcurrentInterpreter.kt
+++ b/staticlibs/testutils/devicetests/com/android/testutils/ConcurrentInterpreter.kt
@@ -13,7 +13,7 @@
 typealias InterpretMatcher<T> = Pair<Regex, (ConcurrentInterpreter<T>, T, MatchResult) -> Any?>
 
 // The default unit of time for interpreted tests
-val INTERPRET_TIME_UNIT = 40L // ms
+const val INTERPRET_TIME_UNIT = 60L // ms
 
 /**
  * A small interpreter for testing parallel code.
@@ -40,10 +40,7 @@
  * Some expressions already exist by default and can be used by all interpreters. Refer to
  * getDefaultInstructions() below for a list and documentation.
  */
-open class ConcurrentInterpreter<T>(
-    localInterpretTable: List<InterpretMatcher<T>>,
-    val interpretTimeUnit: Long = INTERPRET_TIME_UNIT
-) {
+open class ConcurrentInterpreter<T>(localInterpretTable: List<InterpretMatcher<T>>) {
     private val interpretTable: List<InterpretMatcher<T>> =
             localInterpretTable + getDefaultInstructions()
     // The last time the thread became blocked, with base System.currentTimeMillis(). This should
@@ -211,7 +208,7 @@
     },
     // Interpret sleep. Optional argument for the count, in INTERPRET_TIME_UNIT units.
     Regex("""sleep(\((\d+)\))?""") to { i, t, r ->
-        SystemClock.sleep(if (r.strArg(2).isEmpty()) i.interpretTimeUnit else r.timeArg(2))
+        SystemClock.sleep(if (r.strArg(2).isEmpty()) INTERPRET_TIME_UNIT else r.timeArg(2))
     },
     Regex("""(.*)\s*fails""") to { i, t, r ->
         assertFails { i.interpret(r.strArg(1), t) }
diff --git a/staticlibs/testutils/devicetests/com/android/testutils/ConnectUtil.kt b/staticlibs/testutils/devicetests/com/android/testutils/ConnectUtil.kt
index fc951d8..7b5ad01 100644
--- a/staticlibs/testutils/devicetests/com/android/testutils/ConnectUtil.kt
+++ b/staticlibs/testutils/devicetests/com/android/testutils/ConnectUtil.kt
@@ -63,6 +63,7 @@
 
         try {
             val connInfo = wifiManager.connectionInfo
+            Log.d(TAG, "connInfo=" + connInfo)
             if (connInfo == null || connInfo.networkId == -1) {
                 clearWifiBlocklist()
                 val pfd = getInstrumentation().uiAutomation.executeShellCommand("svc wifi enable")
@@ -75,7 +76,7 @@
                     timeoutMs = WIFI_CONNECT_TIMEOUT_MS)
 
             assertNotNull(cb, "Could not connect to a wifi access point within " +
-                    "$WIFI_CONNECT_INTERVAL_MS ms. Check that the test device has a wifi network " +
+                    "$WIFI_CONNECT_TIMEOUT_MS ms. Check that the test device has a wifi network " +
                     "configured, and that the test access point is functioning properly.")
             return cb.network
         } finally {
diff --git a/staticlibs/testutils/devicetests/com/android/testutils/DeviceConfigRule.kt b/staticlibs/testutils/devicetests/com/android/testutils/DeviceConfigRule.kt
new file mode 100644
index 0000000..3d98cc3
--- /dev/null
+++ b/staticlibs/testutils/devicetests/com/android/testutils/DeviceConfigRule.kt
@@ -0,0 +1,176 @@
+/*
+ * Copyright (C) 2022 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.testutils
+
+import android.Manifest.permission.READ_DEVICE_CONFIG
+import android.Manifest.permission.WRITE_DEVICE_CONFIG
+import android.provider.DeviceConfig
+import android.util.Log
+import com.android.modules.utils.build.SdkLevel
+import com.android.testutils.FunctionalUtils.ThrowingRunnable
+import org.junit.rules.TestRule
+import org.junit.runner.Description
+import org.junit.runners.model.Statement
+import java.util.concurrent.CompletableFuture
+import java.util.concurrent.Executor
+import java.util.concurrent.TimeUnit
+
+private val TAG = DeviceConfigRule::class.simpleName
+
+private const val TIMEOUT_MS = 20_000L
+
+/**
+ * A [TestRule] that helps set [DeviceConfig] for tests and clean up the test configuration
+ * automatically on teardown.
+ *
+ * The rule can also optionally retry tests when they fail following an external change of
+ * DeviceConfig before S; this typically happens because device config flags are synced while the
+ * test is running, and DisableConfigSyncTargetPreparer is only usable starting from S.
+ *
+ * @param retryCountBeforeSIfConfigChanged if > 0, when the test fails before S, check if
+ *        the configs that were set through this rule were changed, and retry the test
+ *        up to the specified number of times if yes.
+ */
+class DeviceConfigRule @JvmOverloads constructor(
+    val retryCountBeforeSIfConfigChanged: Int = 0
+) : TestRule {
+    // Maps (namespace, key) -> value
+    private val originalConfig = mutableMapOf<Pair<String, String>, String?>()
+    private val usedConfig = mutableMapOf<Pair<String, String>, String?>()
+
+    /**
+     * Actions to be run after cleanup of the config, for the current test only.
+     */
+    private val currentTestCleanupActions = mutableListOf<ThrowingRunnable>()
+
+    override fun apply(base: Statement, description: Description): Statement {
+        return TestValidationUrlStatement(base, description)
+    }
+
+    private inner class TestValidationUrlStatement(
+        private val base: Statement,
+        private val description: Description
+    ) : Statement() {
+        override fun evaluate() {
+            var retryCount = if (SdkLevel.isAtLeastS()) 1 else retryCountBeforeSIfConfigChanged + 1
+            while (retryCount > 0) {
+                retryCount--
+                tryTest {
+                    base.evaluate()
+                    // Can't use break/return out of a loop here because this is a tryTest lambda,
+                    // so set retryCount to exit instead
+                    retryCount = 0
+                }.catch<Throwable> { e -> // junit AssertionFailedError does not extend Exception
+                    if (retryCount == 0) throw e
+                    usedConfig.forEach { (key, value) ->
+                        val currentValue = runAsShell(READ_DEVICE_CONFIG) {
+                            DeviceConfig.getProperty(key.first, key.second)
+                        }
+                        if (currentValue != value) {
+                            Log.w(TAG, "Test failed with unexpected device config change, retrying")
+                            return@catch
+                        }
+                    }
+                    throw e
+                } cleanupStep {
+                    runAsShell(WRITE_DEVICE_CONFIG) {
+                        originalConfig.forEach { (key, value) ->
+                            DeviceConfig.setProperty(
+                                    key.first, key.second, value, false /* makeDefault */)
+                        }
+                    }
+                } cleanupStep {
+                    originalConfig.clear()
+                    usedConfig.clear()
+                } cleanup {
+                    // Fold all cleanup actions into cleanup steps of an empty tryTest, so they are
+                    // all run even if exceptions are thrown, and exceptions are reported properly.
+                    currentTestCleanupActions.fold(tryTest { }) {
+                        tryBlock, action -> tryBlock.cleanupStep { action.run() }
+                    }.cleanup {
+                        currentTestCleanupActions.clear()
+                    }
+                }
+            }
+        }
+    }
+
+    /**
+     * Set a configuration key/value. After the test case ends, it will be restored to the value it
+     * had when this method was first called.
+     *
+     * Unless the property already has this value, this blocks until a change notification
+     * confirms it; waiting is bounded by [TIMEOUT_MS].
+     * @param namespace the DeviceConfig namespace in which to set the property.
+     * @param key the name of the property.
+     * @param value the value to set; null means absent, as noted in the listener below.
+     * @return [value], once it is in effect.
+     */
+    fun setConfig(namespace: String, key: String, value: String?): String? {
+        Log.i(TAG, "Setting config \"$key\" to \"$value\"")
+        val readWritePermissions = arrayOf(READ_DEVICE_CONFIG, WRITE_DEVICE_CONFIG)
+
+        val keyPair = Pair(namespace, key)
+        val existingValue = runAsShell(*readWritePermissions) {
+            DeviceConfig.getProperty(namespace, key)
+        }
+        if (!originalConfig.containsKey(keyPair)) {
+            originalConfig[keyPair] = existingValue
+        }
+        usedConfig[keyPair] = value
+        if (existingValue == value) {
+            // Already the correct value. There may be a race if a change is already in flight,
+            // but if multiple threads update the config there is no way to fix that anyway.
+            Log.i(TAG, "\"$key\" already had value \"$value\"")
+            return value
+        }
+
+        val future = CompletableFuture<String>()
+        val listener = DeviceConfig.OnPropertiesChangedListener {
+            // The listener receives updates for any change to any key in |namespace|, so don't
+            // react to changes that do not affect the relevant key
+            if (!it.keyset.contains(key)) return@OnPropertiesChangedListener
+            // "null" means absent in DeviceConfig : there is no such thing as a present but
+            // null value, so the following works even if |value| is null.
+            if (it.getString(key, null) == value) {
+                future.complete(value)
+            }
+        }
+
+        return tryTest {
+            runAsShell(*readWritePermissions) {
+                DeviceConfig.addOnPropertiesChangedListener(namespace, inlineExecutor, listener)
+                DeviceConfig.setProperty(namespace, key, value, false /* makeDefault */)
+                // Don't drop the permission until the config is applied, just in case
+                future.get(TIMEOUT_MS, TimeUnit.MILLISECONDS)
+            }.also {
+                Log.i(TAG, "Config \"$key\" successfully set to \"$value\"")
+            }
+        } cleanup {
+            DeviceConfig.removeOnPropertiesChangedListener(listener)
+        }
+    }
+
+    private val inlineExecutor get() = Executor { r -> r.run() }
+
+    /**
+     * Add an action to be run after config cleanup when the current test case ends.
+     */
+    fun runAfterNextCleanup(action: ThrowingRunnable) {
+        currentTestCleanupActions.add(action)
+    }
+}
diff --git a/staticlibs/testutils/devicetests/com/android/testutils/DeviceInfoUtils.java b/staticlibs/testutils/devicetests/com/android/testutils/DeviceInfoUtils.java
index ea89eda..ce55fdc 100644
--- a/staticlibs/testutils/devicetests/com/android/testutils/DeviceInfoUtils.java
+++ b/staticlibs/testutils/devicetests/com/android/testutils/DeviceInfoUtils.java
@@ -20,6 +20,7 @@
 import android.text.TextUtils;
 import android.util.Pair;
 
+import java.util.Objects;
 import java.util.regex.Matcher;
 import java.util.regex.Pattern;
 
@@ -128,7 +129,7 @@
         final Pair<Integer, Integer> v1 = getMajorMinorVersion(s1);
         final Pair<Integer, Integer> v2 = getMajorMinorVersion(s2);
 
-        if (v1.first == v2.first) {
+        if (Objects.equals(v1.first, v2.first)) {
             return Integer.compare(v1.second, v2.second);
         } else {
             return Integer.compare(v1.first, v2.first);
diff --git a/staticlibs/testutils/devicetests/com/android/testutils/TestNetworkTracker.kt b/staticlibs/testutils/devicetests/com/android/testutils/TestNetworkTracker.kt
index 5ae2439..b743b6c 100644
--- a/staticlibs/testutils/devicetests/com/android/testutils/TestNetworkTracker.kt
+++ b/staticlibs/testutils/devicetests/com/android/testutils/TestNetworkTracker.kt
@@ -111,12 +111,12 @@
     val tnm: TestNetworkManager,
     val lp: LinkProperties?,
     setupTimeoutMs: Long
-) {
+) : TestableNetworkCallback.HasNetwork {
     private val cm = context.getSystemService(ConnectivityManager::class.java)
     private val binder = Binder()
 
     private val networkCallback: NetworkCallback
-    val network: Network
+    override val network: Network
     val testIface: TestNetworkInterface
 
     init {
diff --git a/staticlibs/testutils/devicetests/com/android/testutils/TestableNetworkCallback.kt b/staticlibs/testutils/devicetests/com/android/testutils/TestableNetworkCallback.kt
index dffdbe8..b84f9a6 100644
--- a/staticlibs/testutils/devicetests/com/android/testutils/TestableNetworkCallback.kt
+++ b/staticlibs/testutils/devicetests/com/android/testutils/TestableNetworkCallback.kt
@@ -36,15 +36,12 @@
 import kotlin.reflect.KClass
 import kotlin.test.assertEquals
 import kotlin.test.assertNotNull
-import kotlin.test.assertTrue
 import kotlin.test.fail
 
 object NULL_NETWORK : Network(-1)
 object ANY_NETWORK : Network(-2)
 fun anyNetwork() = ANY_NETWORK
 
-private val Int.capabilityName get() = NetworkCapabilities.capabilityNameOf(this)
-
 open class RecorderCallback private constructor(
     private val backingRecord: ArrayTrackRecord<CallbackEntry>
 ) : NetworkCallback() {
@@ -168,16 +165,22 @@
     }
 }
 
-private const val DEFAULT_TIMEOUT = 200L // ms
+private const val DEFAULT_TIMEOUT = 30_000L // ms
+private const val DEFAULT_NO_CALLBACK_TIMEOUT = 200L // ms
 
 open class TestableNetworkCallback private constructor(
     src: TestableNetworkCallback?,
-    val defaultTimeoutMs: Long = DEFAULT_TIMEOUT
+    val defaultTimeoutMs: Long = DEFAULT_TIMEOUT,
+    val defaultNoCallbackTimeoutMs: Long = DEFAULT_NO_CALLBACK_TIMEOUT
 ) : RecorderCallback(src) {
     @JvmOverloads
-    constructor(timeoutMs: Long = DEFAULT_TIMEOUT): this(null, timeoutMs)
+    constructor(
+        timeoutMs: Long = DEFAULT_TIMEOUT,
+        noCallbackTimeoutMs: Long = DEFAULT_NO_CALLBACK_TIMEOUT
+    ): this(null, timeoutMs, noCallbackTimeoutMs)
 
-    fun createLinkedCopy() = TestableNetworkCallback(this, defaultTimeoutMs)
+    fun createLinkedCopy() = TestableNetworkCallback(
+            this, defaultTimeoutMs, defaultNoCallbackTimeoutMs)
 
     // The last available network, or null if any network was lost since the last call to
     // onAvailable. TODO : fix this by fixing the tests that rely on this behavior
@@ -187,14 +190,163 @@
             else -> null
         }
 
-    fun pollForNextCallback(timeoutMs: Long = defaultTimeoutMs): CallbackEntry {
-        return history.poll(timeoutMs) ?: fail("Did not receive callback after ${timeoutMs}ms")
-    }
+    /**
+     * Get the next callback or null if timeout.
+     *
+     * With no argument, this method waits out the default timeout. To wait forever, pass
+     * Long.MAX_VALUE.
+     */
+    @JvmOverloads
+    fun poll(timeoutMs: Long = defaultTimeoutMs): CallbackEntry? = history.poll(timeoutMs)
+
+    /**
+     * Get the next callback, or fail with [errorMsg] if none arrives within [timeoutMs].
+     *
+     * With no argument, this method waits out the default timeout. To wait forever, pass
+     * Long.MAX_VALUE.
+     */
+    @JvmOverloads
+    fun pollOrThrow(
+        timeoutMs: Long = defaultTimeoutMs,
+        errorMsg: String = "Did not receive callback after ${timeoutMs}ms"
+    ): CallbackEntry = poll(timeoutMs) ?: fail(errorMsg)
+
+    /*****
+     * expect family of methods. These fetch the next callback and assert it matches the
+     * conditions : type, network, predicate. They fail if none is received within the timeout.
+     */
+    @JvmOverloads
+    fun <T : CallbackEntry> expect(
+        type: KClass<T>,
+        network: Network = ANY_NETWORK,
+        timeoutMs: Long = defaultTimeoutMs,
+        errorMsg: String? = null,
+        test: (T) -> Boolean = { true }
+    ) = expect<CallbackEntry>(network, timeoutMs, errorMsg) {
+        if (!type.java.isInstance(it)) fail("Expected callback ${type.simpleName}, got $it")
+        test(it as T) // T is erased so this cast is unchecked; |type| was verified just above
+    } as T
+
+    @JvmOverloads
+    fun <T : CallbackEntry> expect(
+        type: KClass<T>,
+        network: HasNetwork,
+        timeoutMs: Long = defaultTimeoutMs,
+        errorMsg: String? = null,
+        test: (T) -> Boolean = { true }
+    ) = expect(type, network.network, timeoutMs, errorMsg, test)
+
+    // Java needs an explicit overload to let it omit arguments in the middle, so define these
+    // here. Note that @JvmOverloads give us the versions without the last arguments too, so
+    // there is no need to explicitly define versions without the test predicate.
+    // Without |network|
+    @JvmOverloads
+    fun <T : CallbackEntry> expect(
+        type: KClass<T>,
+        timeoutMs: Long,
+        errorMsg: String?,
+        test: (T) -> Boolean = { true }
+    ) = expect(type, ANY_NETWORK, timeoutMs, errorMsg, test)
+
+    // Without |timeout|, in Network and HasNetwork versions
+    @JvmOverloads
+    fun <T : CallbackEntry> expect(
+        type: KClass<T>,
+        network: Network,
+        errorMsg: String?,
+        test: (T) -> Boolean = { true }
+    ) = expect(type, network, defaultTimeoutMs, errorMsg, test)
+
+    @JvmOverloads
+    fun <T : CallbackEntry> expect(
+        type: KClass<T>,
+        network: HasNetwork,
+        errorMsg: String?,
+        test: (T) -> Boolean = { true }
+    ) = expect(type, network.network, defaultTimeoutMs, errorMsg, test)
+
+    // Without |errorMsg|, in Network and HasNetwork versions
+    @JvmOverloads
+    fun <T : CallbackEntry> expect(
+        type: KClass<T>,
+        network: Network,
+        timeoutMs: Long,
+        test: (T) -> Boolean
+    ) = expect(type, network, timeoutMs, null, test)
+
+    @JvmOverloads
+    fun <T : CallbackEntry> expect(
+        type: KClass<T>,
+        network: HasNetwork,
+        timeoutMs: Long,
+        test: (T) -> Boolean
+    ) = expect(type, network.network, timeoutMs, null, test)
+
+    // Without |network| or |timeout|
+    @JvmOverloads
+    fun <T : CallbackEntry> expect(
+        type: KClass<T>,
+        errorMsg: String?,
+        test: (T) -> Boolean = { true }
+    ) = expect(type, ANY_NETWORK, defaultTimeoutMs, errorMsg, test)
+
+    // Without |network| or |errorMsg|
+    @JvmOverloads
+    fun <T : CallbackEntry> expect(
+        type: KClass<T>,
+        timeoutMs: Long,
+        test: (T) -> Boolean = { true }
+    ) = expect(type, ANY_NETWORK, timeoutMs, null, test)
+
+    // Without |timeout| or |errorMsg|, in Network and HasNetwork versions
+    @JvmOverloads
+    fun <T : CallbackEntry> expect(
+        type: KClass<T>,
+        network: Network,
+        test: (T) -> Boolean
+    ) = expect(type, network, defaultTimeoutMs, null, test)
+
+    @JvmOverloads
+    fun <T : CallbackEntry> expect(
+        type: KClass<T>,
+        network: HasNetwork,
+        test: (T) -> Boolean
+    ) = expect(type, network.network, defaultTimeoutMs, null, test)
+
+    // Without |network| or |timeout| or |errorMsg|
+    @JvmOverloads
+    fun <T : CallbackEntry> expect(
+        type: KClass<T>,
+        test: (T) -> Boolean
+    ) = expect(type, ANY_NETWORK, defaultTimeoutMs, null, test)
+
+    // Kotlin reified versions. Don't call methods above, or the predicate would need to be noinline
+    inline fun <reified T : CallbackEntry> expect(
+        network: Network = ANY_NETWORK,
+        timeoutMs: Long = defaultTimeoutMs,
+        errorMsg: String? = null,
+        test: (T) -> Boolean = { true }
+    ) = pollOrThrow(timeoutMs).also {
+        if (it !is T) fail("Expected callback ${T::class.simpleName}, got $it")
+        if (ANY_NETWORK !== network && it.network != network) {
+            fail("Expected network $network for callback : $it")
+        }
+        if (!test(it)) {
+            fail("${errorMsg ?: "Callback doesn't match predicate"} : $it")
+        }
+    } as T
+
+    inline fun <reified T : CallbackEntry> expect(
+        network: HasNetwork,
+        timeoutMs: Long = defaultTimeoutMs,
+        errorMsg: String? = null,
+        test: (T) -> Boolean = { true }
+    ) = expect(network.network, timeoutMs, errorMsg, test)
 
     // Make open for use in ConnectivityServiceTest which is the only one knowing its handlers.
     // TODO : remove the necessity to overload this, remove the open qualifier, and give a
     // default argument to assertNoCallback instead, possibly with @JvmOverloads if necessary.
-    open fun assertNoCallback() = assertNoCallback(defaultTimeoutMs)
+    open fun assertNoCallback() = assertNoCallback(defaultNoCallbackTimeoutMs)
 
     fun assertNoCallback(timeoutMs: Long) {
         val cb = history.poll(timeoutMs)
@@ -202,7 +354,7 @@
     }
 
     fun assertNoCallbackThat(
-        timeoutMs: Long = defaultTimeoutMs,
+        timeoutMs: Long = defaultNoCallbackTimeoutMs,
         valid: (CallbackEntry) -> Boolean
     ) {
         val cb = history.poll(timeoutMs) { valid(it) }.let {
@@ -210,19 +362,6 @@
         }
     }
 
-    // Expects a callback of the specified type on the specified network within the timeout.
-    // If no callback arrives, or a different callback arrives, fail. Returns the callback.
-    inline fun <reified T : CallbackEntry> expectCallback(
-        network: Network = ANY_NETWORK,
-        timeoutMs: Long = defaultTimeoutMs
-    ): T = pollForNextCallback(timeoutMs).let {
-        if (it !is T || (ANY_NETWORK !== network && it.network != network)) {
-            fail("Unexpected callback : $it, expected ${T::class} with Network[$network]")
-        } else {
-            it
-        }
-    }
-
     // Expects a callback of the specified type matching the predicate within the timeout.
     // Any callback that doesn't match the predicate will be skipped. Fails only if
     // no matching callback is received within the timeout.
@@ -237,8 +376,17 @@
     fun <T : CallbackEntry> eventuallyExpect(
         type: KClass<T>,
         timeoutMs: Long = defaultTimeoutMs,
-        predicate: (T: CallbackEntry) -> Boolean = { true }
-    ) = history.poll(timeoutMs) { type.java.isInstance(it) && predicate(it) }.also {
+        predicate: (cb: T) -> Boolean = { true }
+    ) = history.poll(timeoutMs) { type.java.isInstance(it) && predicate(it as T) }.also {
+        assertNotNull(it, "Callback ${type.java} not received within ${timeoutMs}ms")
+    } as T
+
+    fun <T : CallbackEntry> eventuallyExpect(
+        type: KClass<T>,
+        timeoutMs: Long = defaultTimeoutMs,
+        from: Int = mark,
+        predicate: (cb: T) -> Boolean = { true }
+    ) = history.poll(timeoutMs, from) { type.java.isInstance(it) && predicate(it as T) }.also {
         assertNotNull(it, "Callback ${type.java} not received within ${timeoutMs}ms")
     } as T
 
@@ -249,30 +397,19 @@
         crossinline predicate: (T) -> Boolean = { true }
     ) = history.poll(timeoutMs, from) { it is T && predicate(it) } as T?
 
-    fun expectCallbackThat(
-        timeoutMs: Long = defaultTimeoutMs,
-        valid: (CallbackEntry) -> Boolean
-    ) = pollForNextCallback(timeoutMs).also { assertTrue(valid(it), "Unexpected callback : $it") }
-
-    fun expectCapabilitiesThat(
+    inline fun expectCapabilitiesThat(
         net: Network,
         tmt: Long = defaultTimeoutMs,
         valid: (NetworkCapabilities) -> Boolean
-    ): CapabilitiesChanged {
-        return expectCallback<CapabilitiesChanged>(net, tmt).also {
-            assertTrue(valid(it.caps), "Capabilities don't match expectations ${it.caps}")
-        }
-    }
+    ): CapabilitiesChanged =
+            expect(net, tmt, "Capabilities don't match expectations") { valid(it.caps) }
 
-    fun expectLinkPropertiesThat(
+    inline fun expectLinkPropertiesThat(
         net: Network,
         tmt: Long = defaultTimeoutMs,
         valid: (LinkProperties) -> Boolean
-    ): LinkPropertiesChanged {
-        return expectCallback<LinkPropertiesChanged>(net, tmt).also {
-            assertTrue(valid(it.lp), "LinkProperties don't match expectations ${it.lp}")
-        }
-    }
+    ): LinkPropertiesChanged =
+            expect(net, tmt, "LinkProperties don't match expectations") { valid(it.lp) }
 
     // Expects onAvailable and the callbacks that follow it. These are:
     // - onSuspended, iff the network was suspended when the callbacks fire.
@@ -313,16 +450,16 @@
         validated: Boolean?,
         tmt: Long
     ) {
-        expectCallback<Available>(net, tmt)
+        expect<Available>(net, tmt)
         if (suspended) {
-            expectCallback<Suspended>(net, tmt)
+            expect<Suspended>(net, tmt)
         }
         expectCapabilitiesThat(net, tmt) {
             validated == null || validated == it.hasCapability(
                 NET_CAPABILITY_VALIDATED
             )
         }
-        expectCallback<LinkPropertiesChanged>(net, tmt)
+        expect<LinkPropertiesChanged>(net, tmt)
     }
 
     // Backward compatibility for existing Java code. Use named arguments instead and remove all
@@ -333,17 +470,15 @@
         tmt: Long = defaultTimeoutMs
     ) = expectAvailableCallbacks(net, suspended = true, validated = validated, tmt = tmt)
 
-    fun expectBlockedStatusCallback(blocked: Boolean, net: Network, tmt: Long = defaultTimeoutMs) {
-        expectCallback<BlockedStatus>(net, tmt).also {
-            assertEquals(blocked, it.blocked, "Unexpected blocked status ${it.blocked}")
-        }
-    }
+    fun expectBlockedStatusCallback(blocked: Boolean, net: Network, tmt: Long = defaultTimeoutMs) =
+            expect<BlockedStatus>(net, tmt, "Unexpected blocked status") {
+                it.blocked == blocked
+            }
 
-    fun expectBlockedStatusCallback(blocked: Int, net: Network, tmt: Long = defaultTimeoutMs) {
-        expectCallback<BlockedStatusInt>(net, tmt).also {
-            assertEquals(blocked, it.blocked, "Unexpected blocked status ${it.blocked}")
-        }
-    }
+    fun expectBlockedStatusCallback(blocked: Int, net: Network, tmt: Long = defaultTimeoutMs) =
+            expect<BlockedStatusInt>(net, tmt, "Unexpected blocked status") {
+                it.blocked == blocked
+            }
 
     // Expects the available callbacks (where the onCapabilitiesChanged must contain the
     // VALIDATED capability), plus another onCapabilitiesChanged which is identical to the
@@ -353,7 +488,7 @@
         val mark = history.mark
         expectAvailableCallbacks(net, tmt = tmt)
         val firstCaps = history.poll(tmt, mark) { it is CapabilitiesChanged }
-        assertEquals(firstCaps, expectCallback<CapabilitiesChanged>(net, tmt))
+        assertEquals(firstCaps, expect<CapabilitiesChanged>(net, tmt))
     }
 
     // Expects the available callbacks where the onCapabilitiesChanged must not have validated,
@@ -381,26 +516,6 @@
         val network: Network
     }
 
-    @JvmOverloads
-    open fun <T : CallbackEntry> expectCallback(
-        type: KClass<T>,
-        n: Network?,
-        timeoutMs: Long = defaultTimeoutMs
-    ) = pollForNextCallback(timeoutMs).also {
-        val network = n ?: NULL_NETWORK
-        // TODO : remove this .java access if the tests ever use kotlin-reflect. At the time of
-        // this writing this would be the only use of this library in the tests.
-        assertTrue(type.java.isInstance(it) && (ANY_NETWORK === n || it.network == network),
-                "Unexpected callback : $it, expected ${type.java} with Network[$network]")
-    } as T
-
-    @JvmOverloads
-    open fun <T : CallbackEntry> expectCallback(
-        type: KClass<T>,
-        n: HasNetwork?,
-        timeoutMs: Long = defaultTimeoutMs
-    ) = expectCallback(type, n?.network, timeoutMs)
-
     fun expectAvailableCallbacks(
         n: HasNetwork,
         suspended: Boolean,