Merge "Add OS access interfaces and test impl for wear tethering"
diff --git a/staticlibs/framework/com/android/net/module/util/BitUtils.java b/staticlibs/framework/com/android/net/module/util/BitUtils.java
index 2b32e86..3062d8c 100644
--- a/staticlibs/framework/com/android/net/module/util/BitUtils.java
+++ b/staticlibs/framework/com/android/net/module/util/BitUtils.java
@@ -17,6 +17,7 @@
package com.android.net.module.util;
import android.annotation.NonNull;
+import android.annotation.Nullable;
/**
* @hide
@@ -107,4 +108,33 @@
++bitPos;
}
}
+
+    /**
+     * Builds a short, human-readable summary of the bits that differ between two bit fields.
+     *
+     * Bits set in {@code oldVal} but not in {@code newVal} are listed first, each preceded by
+     * "-"; bits newly set in {@code newVal} follow, each preceded by "+". For example, with a
+     * fetcher that returns the bit index, going from 0b0001 to 0b0100 yields {@code "-0+2"}.
+     *
+     * @param oldVal the old bit field to diff from
+     * @param newVal the new bit field to diff to
+     * @param nameFetcher turns a bit position into the name printed for it
+     * @return a string fit for logging the differences, or null if there are none. This method
+     *         never returns the empty string.
+     */
+    @Nullable
+    public static String describeDifferences(final long oldVal, final long newVal,
+            @NonNull final NameOf nameFetcher) {
+        final long flipped = oldVal ^ newVal;
+        if (flipped == 0) return null;
+
+        // At least one bit flipped, so at least one of the two sections below appends text
+        // and the result can never be "".
+        final StringBuilder out = new StringBuilder();
+        final long cleared = flipped & oldVal;
+        if (cleared != 0) {
+            out.append("-");
+            appendStringRepresentationOfBitMaskToStringBuilder(out, cleared, nameFetcher, "-");
+        }
+        final long set = flipped & newVal;
+        if (set != 0) {
+            out.append("+");
+            appendStringRepresentationOfBitMaskToStringBuilder(out, set, nameFetcher, "+");
+        }
+        return out.toString();
+    }
}
diff --git a/staticlibs/native/bpf_headers/BpfRingbufTest.cpp b/staticlibs/native/bpf_headers/BpfRingbufTest.cpp
index 4a45a93..d23afae 100644
--- a/staticlibs/native/bpf_headers/BpfRingbufTest.cpp
+++ b/staticlibs/native/bpf_headers/BpfRingbufTest.cpp
@@ -23,14 +23,20 @@
#include <unistd.h>
#include "BpfSyscallWrappers.h"
+#include "bpf/BpfRingbuf.h"
#include "bpf/BpfUtils.h"
+#define TEST_RINGBUF_MAGIC_NUM 12345
+
namespace android {
namespace bpf {
using ::android::base::testing::HasError;
using ::android::base::testing::HasValue;
-using ::android::base::testing::WithMessage;
+using ::android::base::testing::WithCode;
+using ::testing::AllOf;
+using ::testing::Gt;
using ::testing::HasSubstr;
+using ::testing::Lt;
class BpfRingbufTest : public ::testing::Test {
protected:
@@ -40,8 +46,11 @@
void SetUp() {
if (!android::bpf::isAtLeastKernelVersion(5, 8, 0)) {
- GTEST_SKIP() << "BPF ring buffers not supported";
- return;
+ GTEST_SKIP() << "BPF ring buffers not supported below 5.8";
+ }
+
+ if (sizeof(unsigned long) != 8) {
+ GTEST_SKIP() << "BPF ring buffers not supported on 32 bit arch";
}
errno = 0;
@@ -51,12 +60,82 @@
<< mProgPath << " was either not found or inaccessible.";
}
+ // Runs the loaded test program once through runProgram() (BPF_PROG_RUN) and
+ // expects the call to return 0. The zeroed 128-byte buffer is only a dummy
+ // input; presumably the program ignores its contents (TODO confirm).
+ void RunProgram() {
+ char fake_skb[128] = {};
+ EXPECT_EQ(runProgram(mProgram, fake_skb, sizeof(fake_skb)), 0);
+ }
+
+ // Runs the program `n` times, then drains the ring buffer once and expects
+ // exactly `n` messages, the last of which equals TEST_RINGBUF_MAGIC_NUM.
+ void RunTestN(int n) {
+ int run_count = 0;
+ uint64_t output = 0;
+ // Records the most recent value and counts invocations.
+ auto callback = [&](const uint64_t& value) {
+ output = value;
+ run_count++;
+ };
+
+ auto result = BpfRingbuf<uint64_t>::Create(mRingbufPath.c_str());
+ ASSERT_RESULT_OK(result);
+
+ for (int i = 0; i < n; i++) {
+ RunProgram();
+ }
+
+ // ConsumeAll's returned count and the callback count must both be n.
+ EXPECT_THAT(result.value()->ConsumeAll(callback), HasValue(n));
+ EXPECT_EQ(output, TEST_RINGBUF_MAGIC_NUM);
+ EXPECT_EQ(run_count, n);
+ }
+
std::string mProgPath;
std::string mRingbufPath;
android::base::unique_fd mProgram;
};
-TEST_F(BpfRingbufTest, CheckSetUp) {}
+// One event, and several events, are each delivered exactly once by ConsumeAll.
+TEST_F(BpfRingbufTest, ConsumeSingle) { RunTestN(1); }
+TEST_F(BpfRingbufTest, ConsumeMultiple) { RunTestN(3); }
+
+// Overfills the ring buffer so that some events are dropped, checks that every
+// surviving event is consumed, and then checks that the drained buffer accepts
+// and delivers a new event.
+TEST_F(BpfRingbufTest, FillAndWrap) {
+ int run_count = 0;
+ auto callback = [&](const uint64_t&) { run_count++; };
+
+ auto result = BpfRingbuf<uint64_t>::Create(mRingbufPath.c_str());
+ ASSERT_RESULT_OK(result);
+
+ // 4kb buffer with 16 byte payloads (8 byte data, 8 byte header) should fill
+ // after 255 iterations. Exceed that so that some events are dropped.
+ constexpr int iterations = 300;
+ for (int i = 0; i < iterations; i++) {
+ RunProgram();
+ }
+
+ // Some events were dropped, but consume all that succeeded.
+ EXPECT_THAT(result.value()->ConsumeAll(callback),
+ HasValue(AllOf(Gt(250), Lt(260))));
+ EXPECT_THAT(run_count, AllOf(Gt(250), Lt(260)));
+
+ // After consuming everything, we should be able to use the ring buffer again.
+ run_count = 0;
+ RunProgram();
+ EXPECT_THAT(result.value()->ConsumeAll(callback), HasValue(1));
+ EXPECT_EQ(run_count, 1);
+}
+
+TEST_F(BpfRingbufTest, WrongTypeSize) {
+ // The program under test writes 8-byte uint64_t values so a ringbuffer for
+ // 1-byte uint8_t values will fail to read from it. Note that the map_def does
+ // not specify the value size, so we fail on read, not creation.
+ auto result = BpfRingbuf<uint8_t>::Create(mRingbufPath.c_str());
+ ASSERT_RESULT_OK(result);
+
+ RunProgram();
+
+ // ConsumeAll reports the size mismatch with errno EMSGSIZE.
+ EXPECT_THAT(result.value()->ConsumeAll([](const uint8_t&) {}),
+ HasError(WithCode(EMSGSIZE)));
+}
+
+// Create() returns an error carrying the errno from retrieving the pinned map;
+// this test expects ENOENT for a path that does not exist.
+TEST_F(BpfRingbufTest, InvalidPath) {
+ EXPECT_THAT(BpfRingbuf<int>::Create("/sys/fs/bpf/bad_path"),
+ HasError(WithCode(ENOENT)));
+}
} // namespace bpf
} // namespace android
diff --git a/staticlibs/native/bpf_headers/include/bpf/BpfRingbuf.h b/staticlibs/native/bpf_headers/include/bpf/BpfRingbuf.h
new file mode 100644
index 0000000..cac1e43
--- /dev/null
+++ b/staticlibs/native/bpf_headers/include/bpf/BpfRingbuf.h
@@ -0,0 +1,261 @@
+/*
+ * Copyright (C) 2022 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include <android-base/result.h>
+#include <android-base/unique_fd.h>
+#include <linux/bpf.h>
+#include <sys/mman.h>
+#include <utils/Log.h>
+
+#include "bpf/BpfUtils.h"
+
+namespace android {
+namespace bpf {
+
+// BpfRingbufBase contains the non-templated functionality of BPF ring buffers.
+// It owns the pinned map fd and the mmap()ed consumer and producer regions.
+class BpfRingbufBase {
+ public:
+  // Unmaps whichever regions Init() managed to map (Init may fail midway).
+  ~BpfRingbufBase() {
+    if (mConsumerPos) munmap(mConsumerPos, mConsumerSize);
+    if (mProducerPos) munmap(mProducerPos, mProducerSize);
+    mConsumerPos = nullptr;
+    mProducerPos = nullptr;
+  }
+
+ protected:
+  // Non-initializing constructor, used by Create. Explicit so that a bare
+  // size_t can never implicitly convert into a ring buffer object.
+  explicit BpfRingbufBase(size_t value_size) : mValueSize(value_size) {}
+
+  // Full construction that aborts on error (use Create/Init to handle errors).
+  BpfRingbufBase(const char* path, size_t value_size) : mValueSize(value_size) {
+    if (auto status = Init(path); !status.ok()) {
+      ALOGE("BpfRingbuf init failed: %s", status.error().message().c_str());
+      abort();
+    }
+  }
+
+  // Not copyable: the object owns an fd and raw mmap()ed pointers, and a copy
+  // would unmap the same regions twice in the destructor.
+  BpfRingbufBase(const BpfRingbufBase&) = delete;
+  BpfRingbufBase& operator=(const BpfRingbufBase&) = delete;
+
+  // Initialize the base ringbuffer components. Must be called exactly once.
+  base::Result<void> Init(const char* path);
+
+  // Consumes all messages from the ring buffer, passing each payload to the
+  // callback. Returns the number of messages consumed.
+  base::Result<int> ConsumeAll(
+      const std::function<void(const void*)>& callback);
+
+  // Replicates c-style void* "byte-wise" pointer addition.
+  template <typename Ptr>
+  static Ptr pointerAddBytes(void* base, ssize_t offset_bytes) {
+    return reinterpret_cast<Ptr>(reinterpret_cast<char*>(base) + offset_bytes);
+  }
+
+  // Rounds len by clearing the busy/discard flag bits, adding the record
+  // header, and aligning up to 8 bytes.
+  static uint32_t roundLength(uint32_t len) {
+    len &= ~(BPF_RINGBUF_BUSY_BIT | BPF_RINGBUF_DISCARD_BIT);
+    len += BPF_RINGBUF_HDR_SZ;
+    return (len + 7) & ~7;
+  }
+
+  const size_t mValueSize;
+
+  // Mapping sizes; only meaningful once the matching pointer is non-null.
+  size_t mConsumerSize = 0;
+  size_t mProducerSize = 0;
+  // max_entries - 1; data offsets are masked with this to wrap around.
+  unsigned long mPosMask = 0;
+  android::base::unique_fd mRingFd;
+
+  void* mDataPos = nullptr;
+  unsigned long* mConsumerPos = nullptr;
+  unsigned long* mProducerPos = nullptr;
+};
+
+// Typed wrapper around an eBPF ring buffer (BPF_MAP_TYPE_RINGBUF): a special
+// eBPF map used to send messages from eBPF programs to userspace through
+// shared memory, with atomics coordinating the producer and consumer. Ring
+// buffers are a faster alternative to eBPF perf buffers.
+//
+// Thread compatible, not thread safe.
+//
+// Note: the kernel and userspace may access a ring buffer at the same time,
+// but all userspace consumers of one ring buffer share a single read
+// position, so which reader sees which message is not guaranteed.
+template <typename Value>
+class BpfRingbuf : public BpfRingbufBase {
+  // Leaves the object uninitialized; only Create() uses it, and Create() runs
+  // Init() before handing the object out.
+  BpfRingbuf() : BpfRingbufBase(sizeof(Value)) {}
+
+ public:
+  using MessageCallback = std::function<void(const Value&)>;
+
+  // Wraps the ring buffer pinned at `path`, aborting the process on failure.
+  // Use Create() instead to receive an error result.
+  BpfRingbuf(const char* path) : BpfRingbufBase(path, sizeof(Value)) {}
+
+  // Wraps the ring buffer pinned at `path`, returning an error on failure.
+  // Nothing guarantees that the ring buffer really carries messages of type
+  // `Value`; only their size is compared, and only inside ConsumeAll.
+  static base::Result<std::unique_ptr<BpfRingbuf<Value>>> Create(
+      const char* path);
+
+  // Drains every pending message, invoking `callback` once per message.
+  // Returns how many messages were consumed (0 if none were pending), or a
+  // non-ok result on error.
+  base::Result<int> ConsumeAll(const MessageCallback& callback);
+};
+
+// Memory-ordering helpers used to share the consumer/producer positions with
+// the kernel producer. They rely on GNU extensions (typeof, statement
+// expressions).
+//
+// ACCESS_ONCE performs exactly one access through a volatile lvalue, so the
+// compiler cannot cache, merge or drop the access.
+#define ACCESS_ONCE(x) (*(volatile typeof(x)*)&(x))
+
+// smp_sync() is a memory barrier. On x86 it is only a compiler barrier: the
+// hardware ordering (TSO) already gives plain loads acquire and plain stores
+// release semantics. On arm64 it is a "dmb ish" (inner-shareable full
+// barrier). Elsewhere it falls back to the GCC full-barrier builtin.
+#if defined(__i386__) || defined(__x86_64__)
+#define smp_sync() asm volatile("" ::: "memory")
+#elif defined(__aarch64__)
+#define smp_sync() asm volatile("dmb ish" ::: "memory")
+#else
+#define smp_sync() __sync_synchronize()
+#endif
+
+// Release store: barrier first, then a single volatile store of v into *p.
+#define smp_store_release(p, v) \
+ do { \
+ smp_sync(); \
+ ACCESS_ONCE(*(p)) = (v); \
+ } while (0)
+
+// Acquire load: a single volatile load of *p, then a barrier; yields the value.
+#define smp_load_acquire(p) \
+ ({ \
+ auto ___p = ACCESS_ONCE(*(p)); \
+ smp_sync(); \
+ ___p; \
+ })
+
+// Opens the ring buffer pinned at `path`, checks it is a BPF_MAP_TYPE_RINGBUF
+// with non-zero max_entries, and maps its consumer page and its producer
+// page + data area. On failure returns an errno-based error; any mapping made
+// before the failure is released by the destructor.
+inline base::Result<void> BpfRingbufBase::Init(const char* path) {
+  if (sizeof(unsigned long) != 8) {
+    return android::base::Error()
+           << "BpfRingbuf does not support 32 bit architectures";
+  }
+  mRingFd.reset(mapRetrieveRW(path));
+  if (!mRingFd.ok()) {
+    return android::base::ErrnoError()
+           << "failed to retrieve ringbuffer at " << path;
+  }
+
+  int map_type = android::bpf::bpfGetFdMapType(mRingFd);
+  if (map_type < 0) {
+    // Keep the errno from the failed query rather than masking it as a type
+    // mismatch (mirrors the max_entries check below).
+    return android::base::ErrnoError() << "failed to read map type of ringbuf";
+  }
+  if (map_type != BPF_MAP_TYPE_RINGBUF) {
+    errno = EINVAL;
+    return android::base::ErrnoError()
+           << "bpf map has wrong type: want BPF_MAP_TYPE_RINGBUF ("
+           << BPF_MAP_TYPE_RINGBUF << ") got " << map_type;
+  }
+
+  int max_entries = android::bpf::bpfGetFdMaxEntries(mRingFd);
+  if (max_entries < 0) {
+    return android::base::ErrnoError()
+           << "failed to read max_entries from ringbuf";
+  }
+  if (max_entries == 0) {
+    errno = EINVAL;
+    return android::base::ErrnoError() << "max_entries must be non-zero";
+  }
+
+  const size_t page_size = getpagesize();
+  // The kernel requires ring buffer sizes to be a power of two, so this mask
+  // wraps data offsets.
+  mPosMask = max_entries - 1;
+  mConsumerSize = page_size;
+  // The producer mapping is one page followed by the data area, which the
+  // kernel maps twice back to back. Widen before multiplying: in int,
+  // 2 * max_entries overflows for rings of 1 GiB or more.
+  mProducerSize = page_size + 2 * static_cast<size_t>(max_entries);
+
+  {
+    void* ptr = mmap(NULL, mConsumerSize, PROT_READ | PROT_WRITE, MAP_SHARED,
+                     mRingFd, 0);
+    if (ptr == MAP_FAILED) {
+      return android::base::ErrnoError()
+             << "failed to mmap ringbuf consumer pages";
+    }
+    mConsumerPos = reinterpret_cast<unsigned long*>(ptr);
+  }
+
+  {
+    void* ptr = mmap(NULL, mProducerSize, PROT_READ, MAP_SHARED, mRingFd,
+                     mConsumerSize);
+    if (ptr == MAP_FAILED) {
+      return android::base::ErrnoError()
+             << "failed to mmap ringbuf producer page";
+    }
+    mProducerPos = reinterpret_cast<unsigned long*>(ptr);
+  }
+
+  // Records start right after the producer position page.
+  mDataPos = pointerAddBytes<void*>(mProducerPos, page_size);
+  return {};
+}
+
+// Walks records from the consumer position up to the producer position seen at
+// entry. Each committed, non-discarded record is passed to `callback`. The
+// consumer position is published after every record. Stops early at a record
+// that is still being written (busy bit set). Fails with EMSGSIZE if a record's
+// length differs from mValueSize; that record is skipped before returning.
+inline base::Result<int> BpfRingbufBase::ConsumeAll(
+    const std::function<void(const void*)>& callback) {
+  // Same type as the base::Result<int> we return, so no 64->32 bit narrowing.
+  int count = 0;
+  unsigned long cons_pos = smp_load_acquire(mConsumerPos);
+  unsigned long prod_pos = smp_load_acquire(mProducerPos);
+  while (cons_pos < prod_pos) {
+    // Find the start of the entry for this read (wrapping is done here).
+    void* start_ptr = pointerAddBytes<void*>(mDataPos, cons_pos & mPosMask);
+
+    // The entry has an 8 byte header containing the sample length.
+    uint32_t length = smp_load_acquire(reinterpret_cast<uint32_t*>(start_ptr));
+
+    // If the sample isn't committed, we're caught up with the producer.
+    if (length & BPF_RINGBUF_BUSY_BIT) return count;
+
+    // Advance past this record whether or not it was discarded.
+    cons_pos += roundLength(length);
+
+    if ((length & BPF_RINGBUF_DISCARD_BIT) == 0) {
+      if (length != mValueSize) {
+        // Skip the malformed record so a later call can make progress.
+        smp_store_release(mConsumerPos, cons_pos);
+        errno = EMSGSIZE;
+        return android::base::ErrnoError()
+               << "BPF ring buffer message has unexpected size (want "
+               << mValueSize << " bytes, got " << length << " bytes)";
+      }
+      callback(pointerAddBytes<const void*>(start_ptr, BPF_RINGBUF_HDR_SZ));
+      count++;
+    }
+
+    // Publish the new consumer position so the space can be reused.
+    smp_store_release(mConsumerPos, cons_pos);
+  }
+
+  return count;
+}
+
+// Allocates an uninitialized wrapper, then runs Init(); forwards Init's error
+// unchanged, otherwise hands out the ready-to-use ring buffer.
+template <typename Value>
+inline base::Result<std::unique_ptr<BpfRingbuf<Value>>>
+BpfRingbuf<Value>::Create(const char* path) {
+  // The no-arg constructor is private, so std::make_unique cannot reach it.
+  std::unique_ptr<BpfRingbuf> ringbuf(new BpfRingbuf);
+  base::Result<void> init_status = ringbuf->Init(path);
+  if (!init_status.ok()) {
+    return init_status.error();
+  }
+  return ringbuf;
+}
+
+// Typed front end over BpfRingbufBase::ConsumeAll: each raw payload is handed
+// to `callback` as a `const Value&`.
+template <typename Value>
+inline base::Result<int> BpfRingbuf<Value>::ConsumeAll(
+    const MessageCallback& callback) {
+  // The base only invokes this for records whose length equals mValueSize,
+  // which the constructors set to sizeof(Value).
+  auto forward_typed = [&callback](const void* payload) {
+    callback(*static_cast<const Value*>(payload));
+  };
+  return BpfRingbufBase::ConsumeAll(forward_typed);
+}
+
+// Drop the local memory-ordering helper macros so they do not leak into files
+// that include this header.
+#undef ACCESS_ONCE
+#undef smp_sync
+#undef smp_store_release
+#undef smp_load_acquire
+
+} // namespace bpf
+} // namespace android
diff --git a/staticlibs/native/bpf_headers/include/bpf/bpf_helpers.h b/staticlibs/native/bpf_headers/include/bpf/bpf_helpers.h
index ed1ee51..36865f3 100644
--- a/staticlibs/native/bpf_headers/include/bpf/bpf_helpers.h
+++ b/staticlibs/native/bpf_headers/include/bpf/bpf_helpers.h
@@ -235,11 +235,12 @@
/* type safe macro to declare a map and related accessor functions */
#define DEFINE_BPF_MAP_EXT(the_map, TYPE, KeyType, ValueType, num_entries, usr, grp, md, \
- selinux, pindir, share) \
+ selinux, pindir, share, min_loader, max_loader, ignore_eng, \
+ ignore_user, ignore_userdebug) \
DEFINE_BPF_MAP_BASE(the_map, TYPE, sizeof(KeyType), sizeof(ValueType), \
num_entries, usr, grp, md, selinux, pindir, share, \
- KVER_NONE, KVER_INF, BPFLOADER_MIN_VER, BPFLOADER_MAX_VER, \
- false, false, false); \
+ KVER_NONE, KVER_INF, min_loader, max_loader, \
+ ignore_eng, ignore_user, ignore_userdebug); \
BPF_MAP_ASSERT_OK(BPF_MAP_TYPE_##TYPE, (num_entries), (md)); \
BPF_ANNOTATE_KV_PAIR(the_map, KeyType, ValueType); \
\
@@ -271,9 +272,11 @@
#error "Bpf Map UID must be left at default of AID_ROOT for BpfLoader prior to v0.28"
#endif
-#define DEFINE_BPF_MAP_UGM(the_map, TYPE, KeyType, ValueType, num_entries, usr, grp, md) \
- DEFINE_BPF_MAP_EXT(the_map, TYPE, KeyType, ValueType, num_entries, usr, grp, md, \
- DEFAULT_BPF_MAP_SELINUX_CONTEXT, DEFAULT_BPF_MAP_PIN_SUBDIR, false)
+// DEFINE_BPF_MAP_EXT with DEFAULT_BPF_MAP_SELINUX_CONTEXT, DEFAULT_BPF_MAP_PIN_SUBDIR,
+// share=false, the BPFLOADER_MIN_VER..BPFLOADER_MAX_VER loader range, and
+// ignore_eng/ignore_user/ignore_userdebug all false.
+#define DEFINE_BPF_MAP_UGM(the_map, TYPE, KeyType, ValueType, num_entries, usr, grp, md) \
+ DEFINE_BPF_MAP_EXT(the_map, TYPE, KeyType, ValueType, num_entries, usr, grp, md, \
+ DEFAULT_BPF_MAP_SELINUX_CONTEXT, DEFAULT_BPF_MAP_PIN_SUBDIR, /*share*/false, \
+ BPFLOADER_MIN_VER, BPFLOADER_MAX_VER, /*ignore_eng*/false, \
+ /*ignore_user*/false, /*ignore_userdebug*/false)
#define DEFINE_BPF_MAP(the_map, TYPE, KeyType, ValueType, num_entries) \
DEFINE_BPF_MAP_UGM(the_map, TYPE, KeyType, ValueType, num_entries, \
diff --git a/staticlibs/native/bpf_syscall_wrappers/include/BpfSyscallWrappers.h b/staticlibs/native/bpf_syscall_wrappers/include/BpfSyscallWrappers.h
index f7d6a38..8502961 100644
--- a/staticlibs/native/bpf_syscall_wrappers/include/BpfSyscallWrappers.h
+++ b/staticlibs/native/bpf_syscall_wrappers/include/BpfSyscallWrappers.h
@@ -150,6 +150,18 @@
});
}
+// Runs the BPF program `prog_fd` once via BPF_PROG_RUN, passing the `data_size`
+// bytes at `data` as its input. Returns the result of the bpf() wrapper
+// (presumably 0 on success, -1 with errno set on failure -- confirm against the
+// wrapper). The program's own return value and any output data are not
+// retrieved, since the attr is a temporary that is never read back.
+// Available in 4.12 and later kernels.
+inline int runProgram(const BPF_FD_TYPE prog_fd, const void* data,
+ const uint32_t data_size) {
+ return bpf(BPF_PROG_RUN, {
+ .test = {
+ .prog_fd = BPF_FD_TO_U32(prog_fd),
+ .data_in = ptr_to_u64(data),
+ .data_size_in = data_size,
+ },
+ });
+}
+
// BPF_OBJ_GET_INFO_BY_FD requires 4.14+ kernel
//
// Note: some fields are only defined in newer kernels (ie. the map_info struct grows
diff --git a/staticlibs/tests/unit/src/com/android/net/module/util/BitUtilsTests.kt b/staticlibs/tests/unit/src/com/android/net/module/util/BitUtilsTests.kt
index 0236716..49940ea 100644
--- a/staticlibs/tests/unit/src/com/android/net/module/util/BitUtilsTests.kt
+++ b/staticlibs/tests/unit/src/com/android/net/module/util/BitUtilsTests.kt
@@ -17,11 +17,13 @@
package com.android.net.module.util
import com.android.net.module.util.BitUtils.appendStringRepresentationOfBitMaskToStringBuilder
+import com.android.net.module.util.BitUtils.describeDifferences
import com.android.net.module.util.BitUtils.packBits
import com.android.net.module.util.BitUtils.unpackBits
-import org.junit.Test
import kotlin.test.assertEquals
+import kotlin.test.assertNull
import kotlin.test.assertTrue
+import org.junit.Test
class BitUtilsTests {
@Test
@@ -58,4 +60,23 @@
assertEquals(expected, it.toString())
}
}
+
+    @Test
+    fun testDescribeDifferences() {
+        // Bit positions are printed via Integer.toString, e.g. bit 2 -> "2".
+        fun diff(from: Long, to: Long) = describeDifferences(from, to, Integer::toString)
+        fun bits(vararg positions: Int): Long = positions.map { 1L shl it }.sum()
+
+        // Identical fields have no differences, so the result is null.
+        listOf(0L, 5L, Long.MAX_VALUE).forEach { assertNull(diff(it, it)) }
+
+        // A single bit added or removed.
+        assertEquals("+0", diff(0, 1))
+        assertEquals("-0", diff(1, 0))
+
+        // Removals are always listed before additions.
+        assertEquals("+0+2", diff(0, 5))
+        assertEquals("+2", diff(1, 5))
+        assertEquals("-0+2", diff(1, 4))
+
+        assertEquals("-0-4-6-9+1+3+11", diff(bits(0, 4, 6, 9), bits(1, 3, 11)))
+        assertEquals("-1-5-9+6+8", diff(bits(0, 1, 3, 4, 5, 9), bits(0, 3, 4, 6, 8)))
+    }
}
diff --git a/staticlibs/tests/unit/src/com/android/testutils/TestableNetworkCallbackTest.kt b/staticlibs/tests/unit/src/com/android/testutils/TestableNetworkCallbackTest.kt
index eed31e0..d36f52a 100644
--- a/staticlibs/tests/unit/src/com/android/testutils/TestableNetworkCallbackTest.kt
+++ b/staticlibs/tests/unit/src/com/android/testutils/TestableNetworkCallbackTest.kt
@@ -140,20 +140,28 @@
val meteredNc = NetworkCapabilities()
val unmeteredNc = NetworkCapabilities().addCapability(NOT_METERED)
// Check that expecting caps (with or without) fails when no callback has been received.
- assertFails { mCallback.expectCapabilitiesWith(NOT_METERED, matcher, SHORT_TIMEOUT_MS) }
- assertFails { mCallback.expectCapabilitiesWithout(NOT_METERED, matcher, SHORT_TIMEOUT_MS) }
+ assertFails {
+ mCallback.expectCaps(matcher, SHORT_TIMEOUT_MS) { it.hasCapability(NOT_METERED) }
+ }
+ assertFails {
+ mCallback.expectCaps(matcher, SHORT_TIMEOUT_MS) { !it.hasCapability(NOT_METERED) }
+ }
// Add NOT_METERED and check that With succeeds and Without fails.
mCallback.onCapabilitiesChanged(net, unmeteredNc)
- mCallback.expectCapabilitiesWith(NOT_METERED, matcher)
+ mCallback.expectCaps(matcher) { it.hasCapability(NOT_METERED) }
mCallback.onCapabilitiesChanged(net, unmeteredNc)
- assertFails { mCallback.expectCapabilitiesWithout(NOT_METERED, matcher, SHORT_TIMEOUT_MS) }
+ assertFails {
+ mCallback.expectCaps(matcher, SHORT_TIMEOUT_MS) { !it.hasCapability(NOT_METERED) }
+ }
// Don't add NOT_METERED and check that With fails and Without succeeds.
mCallback.onCapabilitiesChanged(net, meteredNc)
- assertFails { mCallback.expectCapabilitiesWith(NOT_METERED, matcher, SHORT_TIMEOUT_MS) }
+ assertFails {
+ mCallback.expectCaps(matcher, SHORT_TIMEOUT_MS) { it.hasCapability(NOT_METERED) }
+ }
mCallback.onCapabilitiesChanged(net, meteredNc)
- mCallback.expectCapabilitiesWithout(NOT_METERED, matcher)
+ mCallback.expectCaps(matcher) { !it.hasCapability(NOT_METERED) }
}
@Test
@@ -179,33 +187,31 @@
}
@Test
- fun testCapabilitiesThat() {
+ fun testExpectCaps() {
 val net = Network(101)
 val netCaps = NetworkCapabilities().addCapability(NOT_METERED).addTransportType(WIFI)
- // Check that expecting capabilitiesThat anything fails when no callback has been received.
+ // Check that expecting any caps fails when no callback has been received.
- assertFails { mCallback.expectCapabilitiesThat(net, SHORT_TIMEOUT_MS) { true } }
+ assertFails { mCallback.expectCaps(net, SHORT_TIMEOUT_MS) { true } }
 // Basic test for true and false
 mCallback.onCapabilitiesChanged(net, netCaps)
- mCallback.expectCapabilitiesThat(net) { true }
+ mCallback.expectCaps(net) { true }
 mCallback.onCapabilitiesChanged(net, netCaps)
- assertFails { mCallback.expectCapabilitiesThat(net, SHORT_TIMEOUT_MS) { false } }
+ assertFails { mCallback.expectCaps(net, SHORT_TIMEOUT_MS) { false } }
 // Try a positive and a negative case
 mCallback.onCapabilitiesChanged(net, netCaps)
- mCallback.expectCapabilitiesThat(net) { caps ->
- caps.hasCapability(NOT_METERED) &&
- caps.hasTransport(WIFI) &&
- !caps.hasTransport(CELLULAR)
+ mCallback.expectCaps(net) {
+ it.hasCapability(NOT_METERED) && it.hasTransport(WIFI) && !it.hasTransport(CELLULAR)
 }
 mCallback.onCapabilitiesChanged(net, netCaps)
- assertFails { mCallback.expectCapabilitiesThat(net, SHORT_TIMEOUT_MS) { caps ->
- caps.hasTransport(CELLULAR)
- } }
+ assertFails { mCallback.expectCaps(net, SHORT_TIMEOUT_MS) { it.hasTransport(CELLULAR) } }
 // Try a matching callback on the wrong network
 mCallback.onCapabilitiesChanged(net, netCaps)
- assertFails { mCallback.expectCapabilitiesThat(Network(100), SHORT_TIMEOUT_MS) { true } }
+ assertFails { mCallback.expectCaps(Network(100), SHORT_TIMEOUT_MS) { true } }
 }
@Test
diff --git a/staticlibs/testutils/devicetests/com/android/testutils/TestableNetworkCallback.kt b/staticlibs/testutils/devicetests/com/android/testutils/TestableNetworkCallback.kt
index 68d5fa9..533ec22 100644
--- a/staticlibs/testutils/devicetests/com/android/testutils/TestableNetworkCallback.kt
+++ b/staticlibs/testutils/devicetests/com/android/testutils/TestableNetworkCallback.kt
@@ -414,13 +414,6 @@
crossinline predicate: (T) -> Boolean = { true }
) = history.poll(timeoutMs, from) { it is T && predicate(it) } as T?
- inline fun expectCapabilitiesThat(
- net: Network,
- tmt: Long = defaultTimeoutMs,
- valid: (NetworkCapabilities) -> Boolean
- ): CapabilitiesChanged =
- expect(net, tmt, "Capabilities don't match expectations") { valid(it.caps) }
-
inline fun expectLinkPropertiesThat(
net: Network,
tmt: Long = defaultTimeoutMs,
@@ -472,10 +465,8 @@
if (suspended) {
expect<Suspended>(net, tmt)
}
- expectCapabilitiesThat(net, tmt) {
- validated == null || validated == it.hasCapability(
- NET_CAPABILITY_VALIDATED
- )
+ expect<CapabilitiesChanged>(net, tmt) {
+ validated == null || validated == it.caps.hasCapability(NET_CAPABILITY_VALIDATED)
}
expect<LinkPropertiesChanged>(net, tmt)
}
@@ -514,7 +505,7 @@
// when a network connects and satisfies a callback, and then immediately validates.
 fun expectAvailableThenValidatedCallbacks(net: Network, tmt: Long = defaultTimeoutMs) {
 expectAvailableCallbacks(net, validated = false, tmt = tmt)
- expectCapabilitiesThat(net, tmt) { it.hasCapability(NET_CAPABILITY_VALIDATED) }
+ // Then expect a separate capabilities change that includes NET_CAPABILITY_VALIDATED.
+ expectCaps(net, tmt) { it.hasCapability(NET_CAPABILITY_VALIDATED) }
 }
fun expectAvailableThenValidatedCallbacks(
@@ -524,7 +515,7 @@
) {
expectAvailableCallbacks(net, validated = false, suspended = false,
blockedStatus = blockedStatus, tmt = tmt)
- expectCapabilitiesThat(net, tmt) { it.hasCapability(NET_CAPABILITY_VALIDATED) }
+ expectCaps(net, tmt) { it.hasCapability(NET_CAPABILITY_VALIDATED) }
}
// Temporary Java compat measure : have MockNetworkAgent implement this so that all existing
@@ -578,29 +569,28 @@
) = expectLinkPropertiesThat(n.network, tmt, valid)
+ /**
+ * Expects a [CapabilitiesChanged] callback for [n]'s network, waiting up to [tmt], whose
+ * capabilities satisfy [valid] (any capabilities by default). Returns those capabilities.
+ */
 @JvmOverloads
- fun expectCapabilitiesThat(
+ fun expectCaps(
 n: HasNetwork,
 tmt: Long = defaultTimeoutMs,
+ valid: (NetworkCapabilities) -> Boolean = { true }
+ ) = expect<CapabilitiesChanged>(n.network, tmt) { valid(it.caps) }.caps
+
+ /** Like the [HasNetwork] variant, but for a plain [Network]; [valid] has no default here. */
+ @JvmOverloads
+ fun expectCaps(
+ n: Network,
+ tmt: Long = defaultTimeoutMs,
 valid: (NetworkCapabilities) -> Boolean
- ) = expectCapabilitiesThat(n.network, tmt, valid)
+ ) = expect<CapabilitiesChanged>(n, tmt) { valid(it.caps) }.caps
- @JvmOverloads
- fun expectCapabilitiesWith(
- capability: Int,
+ // Explicit (HasNetwork, predicate) overload using expect()'s default timeout. @JvmOverloads
+ // on the first variant only drops trailing parameters, so it cannot generate this signature.
+ fun expectCaps(
 n: HasNetwork,
- timeoutMs: Long = defaultTimeoutMs
- ): NetworkCapabilities {
- return expectCapabilitiesThat(n.network, timeoutMs) { it.hasCapability(capability) }.caps
- }
+ valid: (NetworkCapabilities) -> Boolean
+ ) = expect<CapabilitiesChanged>(n.network) { valid(it.caps) }.caps
- @JvmOverloads
- fun expectCapabilitiesWithout(
- capability: Int,
- n: HasNetwork,
- timeoutMs: Long = defaultTimeoutMs
- ): NetworkCapabilities {
- return expectCapabilitiesThat(n.network, timeoutMs) { !it.hasCapability(capability) }.caps
- }
+ /** Expects caps satisfying [valid] from a callback matched against ANY_NETWORK within [tmt]. */
+ fun expectCaps(
+ tmt: Long,
+ valid: (NetworkCapabilities) -> Boolean
+ ) = expect<CapabilitiesChanged>(ANY_NETWORK, tmt) { valid(it.caps) }.caps
fun expectBlockedStatusCallback(expectBlocked: Boolean, n: HasNetwork) {
expectBlockedStatusCallback(expectBlocked, n.network, defaultTimeoutMs)
diff --git a/staticlibs/testutils/devicetests/com/android/testutils/filters/CtsNetTestCasesMaxTargetSdk33.kt b/staticlibs/testutils/devicetests/com/android/testutils/filters/CtsNetTestCasesMaxTargetSdk33.kt
new file mode 100644
index 0000000..5af890f
--- /dev/null
+++ b/staticlibs/testutils/devicetests/com/android/testutils/filters/CtsNetTestCasesMaxTargetSdk33.kt
@@ -0,0 +1,22 @@
+/*
+ * Copyright (C) 2023 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.testutils.filters
+
+/**
+ * Marks a test that should only run in the CtsNetTestCasesMaxTargetSdk33 suite.
+ *
+ * @property reason free-form text; presumably explains why the test is restricted to that suite.
+ */
+@Retention(AnnotationRetention.RUNTIME)
+annotation class CtsNetTestCasesMaxTargetSdk33(val reason: String)