Merge "Add withCleanCallingIdentity with Supplier to module utils"
diff --git a/staticlibs/Android.bp b/staticlibs/Android.bp
index f4997a1..63c8b4e 100644
--- a/staticlibs/Android.bp
+++ b/staticlibs/Android.bp
@@ -70,6 +70,7 @@
   libs: [
       "androidx.annotation_annotation",
       "framework-annotations-lib",
+      "framework-configinfrastructure",
       "framework-connectivity.stubs.module_lib",
   ],
   lint: { strict_updatability_linting: true },
diff --git a/staticlibs/framework/com/android/net/module/util/CollectionUtils.java b/staticlibs/framework/com/android/net/module/util/CollectionUtils.java
index 7cac90d..f08880c 100644
--- a/staticlibs/framework/com/android/net/module/util/CollectionUtils.java
+++ b/staticlibs/framework/com/android/net/module/util/CollectionUtils.java
@@ -18,12 +18,15 @@
 
 import android.annotation.NonNull;
 import android.annotation.Nullable;
+import android.util.ArrayMap;
+import android.util.Pair;
 import android.util.SparseArray;
 
 import java.util.ArrayList;
 import java.util.Collection;
 import java.util.List;
 import java.util.Objects;
+import java.util.function.Function;
 import java.util.function.Predicate;
 
 /**
@@ -291,4 +294,100 @@
             @NonNull final Predicate<? super T> condition) {
         return -1 != indexOf(haystack, condition);
     }
+
+    /**
+     * Standard map function, but returns a new modifiable ArrayList.
+     *
+     * This returns a new list that contains, for each element of the source collection, its
+     * image through the passed transform.
+     * Elements in the source can be null if the transform accepts null inputs.
+     * Elements in the output can be null if the transform ever returns null.
+     * This function never returns null. If the source collection is empty, it returns a new
+     * empty list.
+     * Contract : this method calls the transform function exactly once for each element in the
+     * collection, in iteration order.
+     *
+     * @param source the source collection
+     * @param transform the function to transform the elements
+     * @param <T> type of source elements
+     * @param <R> type of destination elements
+     * @return a new modifiable ArrayList of transformed elements, in iteration order
+     */
+    @NonNull
+    public static <T, R> ArrayList<R> map(@NonNull final Collection<T> source,
+            @NonNull final Function<? super T, ? extends R> transform) {
+        final ArrayList<R> dest = new ArrayList<>(source.size());
+        for (final T e : source) {
+            dest.add(transform.apply(e));
+        }
+        return dest;
+    }
+
+    /**
+     * Pairs up two lists element by element into a new modifiable ArrayList.
+     *
+     * The element at index i of the returned list is a {@link Pair} of the element at index i
+     * of {@code first} and the element at index i of {@code second}; picture the two lists laid
+     * side by side and stitched together with a zipper.
+     * Both lists must have the same size. Either list may contain null elements.
+     * If both lists are empty, a new empty list is returned.
+     *
+     * Contract : this method will read each element of each list exactly once, in some unspecified
+     * order. If it throws, it will not read any element.
+     *
+     * @param first the list supplying the first element of each pair
+     * @param second the list supplying the second element of each pair
+     * @param <T> the type of first elements
+     * @param <R> the type of second elements
+     * @return a new list of pairs, in the order of the input lists
+     * @throws IllegalArgumentException if the two lists do not have the same size
+     */
+    @NonNull
+    public static <T, R> ArrayList<Pair<T, R>> zip(@NonNull final List<T> first,
+            @NonNull final List<R> second) {
+        final int count = first.size();
+        if (second.size() != count) {
+            throw new IllegalArgumentException("zip : collections must be the same size");
+        }
+        final ArrayList<Pair<T, R>> zipped = new ArrayList<>(count);
+        for (int index = 0; index < count; index++) {
+            zipped.add(new Pair<>(first.get(index), second.get(index)));
+        }
+        return zipped;
+    }
+
+    /**
+     * Builds a new ArrayMap associating each key with the value at the same index.
+     *
+     * Both lists must have the same size. Nulls are allowed in both lists, but a given key
+     * (including null) may appear at most once in {@code keys}.
+     *
+     * Contract : this method will read each element of each list exactly once, but does not
+     * specify the order, except if it throws in which case the number of reads is undefined.
+     *
+     * @param keys The list of keys
+     * @param values The list of values, index-aligned with {@code keys}
+     * @param <T> The type of keys
+     * @param <R> The type of values
+     * @return A new map associating each key with its value
+     * @throws IllegalArgumentException if the lists differ in size or a key is repeated
+     */
+    @NonNull
+    public static <T, R> ArrayMap<T, R> assoc(
+            @NonNull final List<T> keys, @NonNull final List<R> values) {
+        final int count = keys.size();
+        if (values.size() != count) {
+            throw new IllegalArgumentException("assoc : collections must be the same size");
+        }
+        final ArrayMap<T, R> result = new ArrayMap<>(count);
+        for (int index = 0; index < count; index++) {
+            final T k = keys.get(index);
+            if (result.containsKey(k)) {
+                throw new IllegalArgumentException(
+                        "assoc : keys may not contain the same value twice");
+            }
+            result.put(k, values.get(index));
+        }
+        return result;
+    }
 }
diff --git a/staticlibs/native/bpf_headers/include/bpf/bpf_helpers.h b/staticlibs/native/bpf_headers/include/bpf/bpf_helpers.h
index c652c76..ea56593 100644
--- a/staticlibs/native/bpf_headers/include/bpf/bpf_helpers.h
+++ b/staticlibs/native/bpf_headers/include/bpf/bpf_helpers.h
@@ -137,6 +137,14 @@
         BPF_FUNC_map_update_elem;
 static int (*bpf_map_delete_elem_unsafe)(const struct bpf_map_def* map,
                                          const void* key) = (void*)BPF_FUNC_map_delete_elem;
+static int (*bpf_ringbuf_output_unsafe)(const struct bpf_map_def* ringbuf,
+                                        const void* data, __u64 size, __u64 flags) = (void*)
+        BPF_FUNC_ringbuf_output;  // untyped; wrapped by bpf_<map>_output() below
+static void* (*bpf_ringbuf_reserve_unsafe)(const struct bpf_map_def* ringbuf,
+                                           __u64 size, __u64 flags) = (void*)
+        BPF_FUNC_ringbuf_reserve;  // untyped; wrapped by bpf_<map>_reserve() below
+static void (*bpf_ringbuf_submit_unsafe)(const void* data, __u64 flags) = (void*)
+        BPF_FUNC_ringbuf_submit;  // untyped; wrapped by bpf_<map>_submit() below
 
 #define BPF_ANNOTATE_KV_PAIR(name, type_key, type_val)  \
         struct ____btf_map_##name {                     \
@@ -147,6 +155,50 @@
         __attribute__ ((section(".maps." #name), used)) \
                 ____btf_map_##name = { }
 
+// Common bpf_map_def definition in the "maps" section, used by macros below.
+#define DEFINE_BPF_MAP_BASE(the_map, TYPE, keysize, valuesize, num_entries, \
+                            usr, grp, md, selinux, pindir, share, minkver,  \
+                            maxkver)                                        \
+  const struct bpf_map_def SECTION("maps") the_map = {                      \
+      .type = BPF_MAP_TYPE_##TYPE,                                          \
+      .key_size = (keysize),                                                \
+      .value_size = (valuesize),                                            \
+      .max_entries = (num_entries),                                         \
+      .map_flags = 0,                                                       \
+      .uid = (usr),                                                         \
+      .gid = (grp),                                                         \
+      .mode = (md),                                                         \
+      .bpfloader_min_ver = DEFAULT_BPFLOADER_MIN_VER,                       \
+      .bpfloader_max_ver = DEFAULT_BPFLOADER_MAX_VER,                       \
+      .min_kver = (minkver),                                                \
+      .max_kver = (maxkver),                                                \
+      .selinux_context = (selinux),                                         \
+      .pin_subdir = (pindir),                                               \
+      .shared = (share),                                                    \
+  };
+
+// Type safe macro to declare a ring buffer and related output functions.
+// Compatibility:
+// * BPF ring buffers are only available in kernels 5.8 and above. Any program
+//   accessing the ring buffer should set a program level min_kver >= 5.8.
+// * The map min_kver of 5.8 requires BPFLOADER_MIN_VER >= BPFLOADER_S_VERSION.
+#define DEFINE_BPF_RINGBUF(the_map, ValueType, size_bytes, usr, grp, md, \
+                           selinux, pindir, share)                       \
+  DEFINE_BPF_MAP_BASE(the_map, RINGBUF, 0, 0, size_bytes, usr, grp, md,  \
+                      selinux, pindir, share, KVER(5, 8, 0), KVER_INF);  \
+  static inline __always_inline __unused int bpf_##the_map##_output(    \
+      const ValueType* v) {                                              \
+    return bpf_ringbuf_output_unsafe(&the_map, v, sizeof(*v), 0);        \
+  }                                                                      \
+  static inline __always_inline __unused                                 \
+      ValueType* bpf_##the_map##_reserve() {                             \
+    return bpf_ringbuf_reserve_unsafe(&the_map, sizeof(ValueType), 0);   \
+  }                                                                      \
+  static inline __always_inline __unused void bpf_##the_map##_submit(    \
+      const ValueType* v) {                                              \
+    bpf_ringbuf_submit_unsafe(v, 0);                                     \
+  }
+
 /* There exist buggy kernels with pre-T OS, that due to
  * kernel patch "[ALPS05162612] bpf: fix ubsan error"
  * do not support userspace writes into non-zero index of bpf map arrays.
@@ -167,23 +219,9 @@
 /* type safe macro to declare a map and related accessor functions */
 #define DEFINE_BPF_MAP_EXT(the_map, TYPE, KeyType, ValueType, num_entries, usr, grp, md,         \
                            selinux, pindir, share)                                               \
-    const struct bpf_map_def SECTION("maps") the_map = {                                         \
-            .type = BPF_MAP_TYPE_##TYPE,                                                         \
-            .key_size = sizeof(KeyType),                                                         \
-            .value_size = sizeof(ValueType),                                                     \
-            .max_entries = (num_entries),                                                        \
-            .map_flags = 0,                                                                      \
-            .uid = (usr),                                                                        \
-            .gid = (grp),                                                                        \
-            .mode = (md),                                                                        \
-            .bpfloader_min_ver = DEFAULT_BPFLOADER_MIN_VER,                                      \
-            .bpfloader_max_ver = DEFAULT_BPFLOADER_MAX_VER,                                      \
-            .min_kver = KVER_NONE,                                                               \
-            .max_kver = KVER_INF,                                                                \
-            .selinux_context = selinux,                                                          \
-            .pin_subdir = pindir,                                                                \
-            .shared = share,                                                                     \
-    };                                                                                           \
+  DEFINE_BPF_MAP_BASE(the_map, TYPE, sizeof(KeyType), sizeof(ValueType),                         \
+                      num_entries, usr, grp, md, selinux, pindir, share,                         \
+                      KVER_NONE, KVER_INF);                                                      \
     BPF_MAP_ASSERT_OK(BPF_MAP_TYPE_##TYPE, (num_entries), (md));                                 \
     BPF_ANNOTATE_KV_PAIR(the_map, KeyType, ValueType);                                           \
                                                                                                  \
diff --git a/staticlibs/tests/unit/src/com/android/net/module/util/CollectionUtilsTest.kt b/staticlibs/tests/unit/src/com/android/net/module/util/CollectionUtilsTest.kt
index 0f00d0b..9fb025b 100644
--- a/staticlibs/tests/unit/src/com/android/net/module/util/CollectionUtilsTest.kt
+++ b/staticlibs/tests/unit/src/com/android/net/module/util/CollectionUtilsTest.kt
@@ -22,6 +22,7 @@
 import org.junit.Test
 import org.junit.runner.RunWith
 import kotlin.test.assertEquals
+import kotlin.test.assertFailsWith
 import kotlin.test.assertFalse
 import kotlin.test.assertNull
 import kotlin.test.assertSame
@@ -133,4 +134,52 @@
         assertSame(CollectionUtils.findLast(listMulti) { it == "E" }, listMulti[7])
         assertNull(CollectionUtils.findLast(listMulti) { it == "F" })
     }
+
+    @Test
+    fun testMap() {
+        val source = listOf("A", "B", "C", "D", "E", null)
+        assertEquals(source.map { s -> "-$s-" }, CollectionUtils.map(source) { s -> "-$s-" })
+    }
+
+    @Test
+    fun testZip() {
+        val letters = listOf("A", "B", "C", "D", "E")
+        val numbers = listOf(1, 2, 3, 4, 5)
+        // Kotlin's own zip yields kotlin.Pair, so build android.util.Pair for the comparison
+        assertEquals(numbers.zip(letters) { n, l -> android.util.Pair(n, l) },
+                CollectionUtils.zip(numbers, letters))
+        val lettersWithNull = listOf("A", null, "B", "C", "D")
+        assertEquals(numbers.zip(lettersWithNull) { n, l -> android.util.Pair(n, l) },
+                CollectionUtils.zip(numbers, lettersWithNull))
+        assertEquals(emptyList<android.util.Pair<Int, Int>>(),
+                CollectionUtils.zip(emptyList<Int>(), emptyList<Int>()))
+        assertFailsWith<IllegalArgumentException> {
+            // Lists of different sizes
+            CollectionUtils.zip(listOf(1, 2), numbers)
+        }
+    }
+
+    @Test
+    fun testAssoc() {
+        val values = listOf("A", "B", "C", "D", "A")
+        val keys = listOf(1, 2, 3, 4, 5)
+        assertEquals(keys.zip(values).toMap(), CollectionUtils.assoc(keys, values))
+
+        // A null key is accepted like any other key
+        val withNullKey = CollectionUtils.assoc(listOf(1, 2, null), listOf("A", "B", "C"))
+        assertEquals("C", withNullKey[null])
+
+        assertFailsWith<IllegalArgumentException> {
+            // Duplicate key
+            CollectionUtils.assoc(listOf("A", "B", "A"), listOf(1, 2, 3))
+        }
+        assertFailsWith<IllegalArgumentException> {
+            // Duplicate key, where the duplicate is null
+            CollectionUtils.assoc(listOf(null, "B", null), listOf(1, 2, 3))
+        }
+        assertFailsWith<IllegalArgumentException> {
+            // Lists of different sizes
+            CollectionUtils.assoc(listOf(1, 2), keys)
+        }
+    }
 }
diff --git a/staticlibs/testutils/devicetests/com/android/testutils/DeviceConfigRule.kt b/staticlibs/testutils/devicetests/com/android/testutils/DeviceConfigRule.kt
new file mode 100644
index 0000000..3d98cc3
--- /dev/null
+++ b/staticlibs/testutils/devicetests/com/android/testutils/DeviceConfigRule.kt
@@ -0,0 +1,176 @@
+/*
+ * Copyright (C) 2022 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.testutils
+
+import android.Manifest.permission.READ_DEVICE_CONFIG
+import android.Manifest.permission.WRITE_DEVICE_CONFIG
+import android.provider.DeviceConfig
+import android.util.Log
+import com.android.modules.utils.build.SdkLevel
+import com.android.testutils.FunctionalUtils.ThrowingRunnable
+import org.junit.rules.TestRule
+import org.junit.runner.Description
+import org.junit.runners.model.Statement
+import java.util.concurrent.CompletableFuture
+import java.util.concurrent.Executor
+import java.util.concurrent.TimeUnit
+
+private const val TAG = "DeviceConfigRule"
+// Maximum time setConfig waits for the new value to be observed by its listener
+private const val TIMEOUT_MS = 20_000L
+
+/**
+ * A [TestRule] that helps set [DeviceConfig] for tests and clean up the test configuration
+ * automatically on teardown.
+ *
+ * The rule can also optionally retry tests when they fail following an external change of
+ * DeviceConfig before S; this typically happens because device config flags are synced while the
+ * test is running, and DisableConfigSyncTargetPreparer is only usable starting from S.
+ *
+ * @param retryCountBeforeSIfConfigChanged if > 0, when the test fails before S, check if
+ *        the configs that were set through this rule were changed, and retry the test
+ *        up to the specified number of times if yes.
+ */
+class DeviceConfigRule @JvmOverloads constructor(
+    val retryCountBeforeSIfConfigChanged: Int = 0
+) : TestRule {
+    // Maps (namespace, key) -> value: original values to restore, and values last set by tests
+    private val originalConfig = mutableMapOf<Pair<String, String>, String?>()
+    private val usedConfig = mutableMapOf<Pair<String, String>, String?>()
+
+    /**
+     * Actions to be run after cleanup of the config, for the current test only.
+     */
+    private val currentTestCleanupActions = mutableListOf<ThrowingRunnable>()
+
+    override fun apply(base: Statement, description: Description): Statement {
+        return DeviceConfigStatement(base, description)
+    }
+
+    private inner class DeviceConfigStatement(
+        private val base: Statement,
+        private val description: Description
+    ) : Statement() {
+        override fun evaluate() {
+            var retryCount = if (SdkLevel.isAtLeastS()) 1 else retryCountBeforeSIfConfigChanged + 1
+            while (retryCount > 0) {
+                retryCount--
+                tryTest {
+                    base.evaluate()
+                    // Can't use break/return out of a loop here because this is a tryTest lambda,
+                    // so set retryCount to exit instead
+                    retryCount = 0
+                }.catch<Throwable> { e -> // junit AssertionFailedError does not extend Exception
+                    if (retryCount == 0) throw e
+                    usedConfig.forEach { (key, value) ->
+                        val currentValue = runAsShell(READ_DEVICE_CONFIG) {
+                            DeviceConfig.getProperty(key.first, key.second)
+                        }
+                        if (currentValue != value) {
+                            Log.w(TAG, "Test failed with unexpected device config change, retrying")
+                            return@catch
+                        }
+                    }
+                    throw e
+                } cleanupStep {
+                    runAsShell(WRITE_DEVICE_CONFIG) {
+                        originalConfig.forEach { (key, value) ->
+                            DeviceConfig.setProperty(
+                                    key.first, key.second, value, false /* makeDefault */)
+                        }
+                    }
+                } cleanupStep {
+                    originalConfig.clear()
+                    usedConfig.clear()
+                } cleanup {
+                    // Fold all cleanup actions into cleanup steps of an empty tryTest, so they are
+                    // all run even if exceptions are thrown, and exceptions are reported properly.
+                    currentTestCleanupActions.fold(tryTest { }) {
+                        tryBlock, action -> tryBlock.cleanupStep { action.run() }
+                    }.cleanup {
+                        currentTestCleanupActions.clear()
+                    }
+                }
+            }
+        }
+    }
+
+    /**
+     * Set [key] in DeviceConfig [namespace] to [value] and wait (up to [TIMEOUT_MS]) until the
+     * change is observed; returns [value]. After the test case ends, the key is restored to the
+     * value it had when this method was first called for that namespace and key.
+     */
+    fun setConfig(namespace: String, key: String, value: String?): String? {
+        Log.i(TAG, "Setting config \"$key\" to \"$value\"")
+        val readWritePermissions = arrayOf(READ_DEVICE_CONFIG, WRITE_DEVICE_CONFIG)
+
+        val keyPair = Pair(namespace, key)
+        val existingValue = runAsShell(*readWritePermissions) {
+            DeviceConfig.getProperty(namespace, key)
+        }
+        if (!originalConfig.containsKey(keyPair)) {
+            originalConfig[keyPair] = existingValue
+        }
+        usedConfig[keyPair] = value
+        if (existingValue == value) {
+            // Already the correct value. There may be a race if a change is already in flight,
+            // but if multiple threads update the config there is no way to fix that anyway.
+            Log.i(TAG, "\"$key\" already had value \"$value\"")
+            return value
+        }
+
+        val future = CompletableFuture<String?>()
+        val listener = DeviceConfig.OnPropertiesChangedListener {
+            // The listener receives updates for any change to any key, so don't react to
+            // changes that do not affect the relevant key
+            if (!it.keyset.contains(key)) return@OnPropertiesChangedListener
+            // "null" means absent in DeviceConfig : there is no such thing as a present but
+            // null value, so the following works even if |value| is null.
+            if (it.getString(key, null) == value) {
+                future.complete(value)
+            }
+        }
+
+        return tryTest {
+            runAsShell(*readWritePermissions) {
+                DeviceConfig.addOnPropertiesChangedListener(
+                        namespace,
+                        inlineExecutor,
+                        listener)
+                DeviceConfig.setProperty(
+                        namespace,
+                        key,
+                        value,
+                        false /* makeDefault */)
+                // Don't drop the permission until the config is applied, just in case
+                future.get(TIMEOUT_MS, TimeUnit.MILLISECONDS)
+            }.also {
+                Log.i(TAG, "Config \"$key\" successfully set to \"$value\"")
+            }
+        } cleanup {
+            DeviceConfig.removeOnPropertiesChangedListener(listener)
+        }
+    }
+
+    private val inlineExecutor = Executor { r -> r.run() }
+
+    /**
+     * Add an action to be run after config cleanup when the current test case ends.
+     */
+    fun runAfterNextCleanup(action: ThrowingRunnable) {
+        currentTestCleanupActions.add(action)
+    }
+}
diff --git a/staticlibs/testutils/devicetests/com/android/testutils/TestableNetworkCallback.kt b/staticlibs/testutils/devicetests/com/android/testutils/TestableNetworkCallback.kt
index 406a179..b84f9a6 100644
--- a/staticlibs/testutils/devicetests/com/android/testutils/TestableNetworkCallback.kt
+++ b/staticlibs/testutils/devicetests/com/android/testutils/TestableNetworkCallback.kt
@@ -343,32 +343,6 @@
         test: (T) -> Boolean = { true }
     ) = expect(network.network, timeoutMs, errorMsg, test)
 
-    // TODO : remove all expectCallback and expectCallbackThat methods after all callers have been
-    // migrated to expect().
-    inline fun <reified T : CallbackEntry> expectCallback(
-        network: Network = ANY_NETWORK,
-        timeoutMs: Long = defaultTimeoutMs
-    ): T = expect(network, timeoutMs)
-
-    @JvmOverloads
-    open fun <T : CallbackEntry> expectCallback(
-        type: KClass<T>,
-        n: Network?,
-        timeoutMs: Long = defaultTimeoutMs
-    ) = expect(type, n ?: ANY_NETWORK, timeoutMs)
-
-    @JvmOverloads
-    open fun <T : CallbackEntry> expectCallback(
-        type: KClass<T>,
-        n: HasNetwork?,
-        timeoutMs: Long = defaultTimeoutMs
-    ) = expect(type, n?.network ?: ANY_NETWORK, timeoutMs)
-
-    fun expectCallbackThat(
-        timeoutMs: Long = defaultTimeoutMs,
-        valid: (CallbackEntry) -> Boolean
-    ) = expect(timeoutMs = timeoutMs, test = valid)
-
     // Make open for use in ConnectivityServiceTest which is the only one knowing its handlers.
     // TODO : remove the necessity to overload this, remove the open qualifier, and give a
     // default argument to assertNoCallback instead, possibly with @JvmOverloads if necessary.
@@ -402,8 +376,17 @@
     fun <T : CallbackEntry> eventuallyExpect(
         type: KClass<T>,
         timeoutMs: Long = defaultTimeoutMs,
-        predicate: (T: CallbackEntry) -> Boolean = { true }
-    ) = history.poll(timeoutMs) { type.java.isInstance(it) && predicate(it) }.also {
+        predicate: (cb: T) -> Boolean = { true } // only called on instances of [type]
+    ) = history.poll(timeoutMs) { type.java.isInstance(it) && predicate(it as T) }.also {
+        assertNotNull(it, "Callback ${type.java} not received within ${timeoutMs}ms")
+    } as T
+
+    fun <T : CallbackEntry> eventuallyExpect(
+        type: KClass<T>,
+        timeoutMs: Long = defaultTimeoutMs,
+        from: Int = mark, // where in history the search starts; passed to history.poll
+        predicate: (cb: T) -> Boolean = { true } // only called on instances of [type]
+    ) = history.poll(timeoutMs, from) { type.java.isInstance(it) && predicate(it as T) }.also {
         assertNotNull(it, "Callback ${type.java} not received within ${timeoutMs}ms")
     } as T