Merge "Add OS access interfaces and test impl for wear tethering"
diff --git a/staticlibs/framework/com/android/net/module/util/BitUtils.java b/staticlibs/framework/com/android/net/module/util/BitUtils.java
index 2b32e86..3062d8c 100644
--- a/staticlibs/framework/com/android/net/module/util/BitUtils.java
+++ b/staticlibs/framework/com/android/net/module/util/BitUtils.java
@@ -17,6 +17,7 @@
 package com.android.net.module.util;
 
 import android.annotation.NonNull;
+import android.annotation.Nullable;
 
 /**
  * @hide
@@ -107,4 +108,33 @@
             ++bitPos;
         }
     }
+
+    /**
+     * Returns a short but human-readable string of updates between an old and a new bit field.
+     *
+     * @param oldVal the old bit field to diff from
+     * @param newVal the new bit field to diff to
+     * @param nameFetcher provides the name to print for each bit that changed
+     * @return a string fit for logging differences, or null if there are none; never empty
+     */
+    @Nullable
+    public static String describeDifferences(final long oldVal, final long newVal,
+            @NonNull final NameOf nameFetcher) {
+        final long changed = oldVal ^ newVal;
+        if (0 == changed) return null;
+        // If the control reaches here, there are changes (additions, removals, or both) so
+        // the code below is guaranteed to add something to the string and can't return "".
+        final long removed = oldVal & changed; // bits set in oldVal only
+        final long added = newVal & changed; // bits set in newVal only
+        final StringBuilder sb = new StringBuilder();
+        if (0 != removed) {
+            sb.append("-");
+            appendStringRepresentationOfBitMaskToStringBuilder(sb, removed, nameFetcher, "-");
+        }
+        if (0 != added) {
+            sb.append("+");
+            appendStringRepresentationOfBitMaskToStringBuilder(sb, added, nameFetcher, "+");
+        }
+        return sb.toString();
+    }
 }
diff --git a/staticlibs/native/bpf_headers/BpfRingbufTest.cpp b/staticlibs/native/bpf_headers/BpfRingbufTest.cpp
index 4a45a93..d23afae 100644
--- a/staticlibs/native/bpf_headers/BpfRingbufTest.cpp
+++ b/staticlibs/native/bpf_headers/BpfRingbufTest.cpp
@@ -23,14 +23,20 @@
 #include <unistd.h>
 
 #include "BpfSyscallWrappers.h"
+#include "bpf/BpfRingbuf.h"
 #include "bpf/BpfUtils.h"
 
+#define TEST_RINGBUF_MAGIC_NUM 12345
+
 namespace android {
 namespace bpf {
 using ::android::base::testing::HasError;
 using ::android::base::testing::HasValue;
-using ::android::base::testing::WithMessage;
+using ::android::base::testing::WithCode;
+using ::testing::AllOf;
+using ::testing::Gt;
 using ::testing::HasSubstr;
+using ::testing::Lt;
 
 class BpfRingbufTest : public ::testing::Test {
  protected:
@@ -40,8 +46,11 @@
 
   void SetUp() {
     if (!android::bpf::isAtLeastKernelVersion(5, 8, 0)) {
-      GTEST_SKIP() << "BPF ring buffers not supported";
-      return;
+      GTEST_SKIP() << "BPF ring buffers not supported below 5.8";
+    }
+
+    if (sizeof(unsigned long) != 8) {
+      GTEST_SKIP() << "BPF ring buffers not supported on 32 bit arch";
     }
 
     errno = 0;
@@ -51,12 +60,82 @@
         << mProgPath << " was either not found or inaccessible.";
   }
 
+  void RunProgram() {
+    char fake_skb[128] = {};
+    EXPECT_EQ(runProgram(mProgram, fake_skb, sizeof(fake_skb)), 0);
+  }
+
+  void RunTestN(int n) {
+    int run_count = 0;
+    uint64_t output = 0;
+    auto callback = [&](const uint64_t& value) {
+      output = value;
+      run_count++;
+    };
+
+    auto result = BpfRingbuf<uint64_t>::Create(mRingbufPath.c_str());
+    ASSERT_RESULT_OK(result);
+
+    for (int i = 0; i < n; i++) {
+      RunProgram();
+    }
+
+    EXPECT_THAT(result.value()->ConsumeAll(callback), HasValue(n));
+    EXPECT_EQ(output, TEST_RINGBUF_MAGIC_NUM);
+    EXPECT_EQ(run_count, n);
+  }
+
   std::string mProgPath;
   std::string mRingbufPath;
   android::base::unique_fd mProgram;
 };
 
-TEST_F(BpfRingbufTest, CheckSetUp) {}
+TEST_F(BpfRingbufTest, ConsumeSingle) { RunTestN(1); }
+TEST_F(BpfRingbufTest, ConsumeMultiple) { RunTestN(3); }
+
+TEST_F(BpfRingbufTest, FillAndWrap) {
+  int run_count = 0;
+  auto callback = [&](const uint64_t&) { run_count++; };
+
+  auto result = BpfRingbuf<uint64_t>::Create(mRingbufPath.c_str());
+  ASSERT_RESULT_OK(result);
+
+  // 4kb buffer with 16 byte payloads (8 byte data, 8 byte header) should fill
+  // after 255 iterations. Exceed that so that some events are dropped.
+  constexpr int iterations = 300;
+  for (int i = 0; i < iterations; i++) {
+    RunProgram();
+  }
+
+  // Some events were dropped, but consume all that succeeded.
+  EXPECT_THAT(result.value()->ConsumeAll(callback),
+              HasValue(AllOf(Gt(250), Lt(260))));
+  EXPECT_THAT(run_count, AllOf(Gt(250), Lt(260)));
+
+  // After consuming everything, we should be able to use the ring buffer again.
+  run_count = 0;
+  RunProgram();
+  EXPECT_THAT(result.value()->ConsumeAll(callback), HasValue(1));
+  EXPECT_EQ(run_count, 1);
+}
+
+TEST_F(BpfRingbufTest, WrongTypeSize) {
+  // The program under test writes 8-byte uint64_t values so a ringbuffer for
+  // 1-byte uint8_t values will fail to read from it. Note that the map_def does
+  // not specify the value size, so we fail on read, not creation.
+  auto result = BpfRingbuf<uint8_t>::Create(mRingbufPath.c_str());
+  ASSERT_RESULT_OK(result);
+
+  RunProgram();
+
+  EXPECT_THAT(result.value()->ConsumeAll([](const uint8_t&) {}),
+              HasError(WithCode(EMSGSIZE)));
+}
+
+TEST_F(BpfRingbufTest, InvalidPath) {
+  EXPECT_THAT(BpfRingbuf<int>::Create("/sys/fs/bpf/bad_path"),
+              HasError(WithCode(ENOENT)));
+}
 
 }  // namespace bpf
 }  // namespace android
diff --git a/staticlibs/native/bpf_headers/include/bpf/BpfRingbuf.h b/staticlibs/native/bpf_headers/include/bpf/BpfRingbuf.h
new file mode 100644
index 0000000..cac1e43
--- /dev/null
+++ b/staticlibs/native/bpf_headers/include/bpf/BpfRingbuf.h
@@ -0,0 +1,261 @@
+/*
+ * Copyright (C) 2022 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include <android-base/result.h>
+#include <android-base/unique_fd.h>
+#include <linux/bpf.h>
+#include <sys/mman.h>
+#include <utils/Log.h>
+
+#include "bpf/BpfUtils.h"
+
+namespace android {
+namespace bpf {
+
+// BpfRingbufBase contains the non-templated functionality of BPF ring buffers.
+class BpfRingbufBase {
+ public:
+  ~BpfRingbufBase() {  // unmap whichever regions Init managed to map
+    if (mConsumerPos) munmap(mConsumerPos, mConsumerSize);
+    if (mProducerPos) munmap(mProducerPos, mProducerSize);
+    mConsumerPos = nullptr;
+    mProducerPos = nullptr;
+  }
+
+ protected:
+  // Non-initializing constructor, used by Create.
+  BpfRingbufBase(size_t value_size) : mValueSize(value_size) {}
+
+  // Full construction that aborts on error (use Create/Init to handle errors).
+  BpfRingbufBase(const char* path, size_t value_size) : mValueSize(value_size) {
+    if (auto status = Init(path); !status.ok()) {
+      ALOGE("BpfRingbuf init failed: %s", status.error().message().c_str());
+      abort();
+    }
+  }
+
+  // Delete copy constructor (class owns raw pointers).
+  BpfRingbufBase(const BpfRingbufBase&) = delete;
+
+  // Initialize the base ringbuffer components. Must be called exactly once.
+  base::Result<void> Init(const char* path);
+
+  // Consumes all messages from the ring buffer, passing them to the callback.
+  base::Result<int> ConsumeAll(
+      const std::function<void(const void*)>& callback);
+
+  // Replicates c-style void* "byte-wise" pointer addition.
+  template <typename Ptr>
+  static Ptr pointerAddBytes(void* base, ssize_t offset_bytes) {
+    return reinterpret_cast<Ptr>(reinterpret_cast<char*>(base) + offset_bytes);
+  }
+
+  // Rounds len by clearing bitmask, adding header, and aligning to 8 bytes.
+  static uint32_t roundLength(uint32_t len) {
+    len &= ~(BPF_RINGBUF_BUSY_BIT | BPF_RINGBUF_DISCARD_BIT);
+    len += BPF_RINGBUF_HDR_SZ;
+    return (len + 7) & ~7;
+  }
+
+  const size_t mValueSize;  // expected payload size of each message, in bytes
+
+  size_t mConsumerSize;  // byte length of the mConsumerPos mapping
+  size_t mProducerSize;  // byte length of the mProducerPos mapping
+  unsigned long mPosMask;  // max_entries - 1, used to wrap positions
+  android::base::unique_fd mRingFd;  // fd retrieved from the pinned map path
+
+  void* mDataPos = nullptr;  // message data, one page past mProducerPos
+  unsigned long* mConsumerPos = nullptr;  // mmap'd consumer position
+  unsigned long* mProducerPos = nullptr;  // mmap'd producer position
+};
+
+// This is a class wrapper for eBPF ring buffers. An eBPF ring buffer is a
+// special type of eBPF map used for sending messages from eBPF to userspace.
+// The implementation relies on fast shared memory and atomics for the producer
+// and consumer management. Ring buffers are a faster alternative to eBPF perf
+// buffers.
+//
+// This class is thread compatible, but not thread safe.
+//
+// Note: A kernel eBPF ring buffer may be accessed by both kernel and userspace
+// processes at the same time. However, the userspace consumers of a given ring
+// buffer all share a single read pointer. There is no guarantee which readers
+// will read which messages.
+template <typename Value>
+class BpfRingbuf : public BpfRingbufBase {
+ public:
+  using MessageCallback = std::function<void(const Value&)>;
+
+  // Creates a ringbuffer wrapper from a pinned path. This initialization will
+  // abort on error. To handle errors, initialize with Create instead.
+  BpfRingbuf(const char* path) : BpfRingbufBase(path, sizeof(Value)) {}
+
+  // Creates a ringbuffer wrapper from a pinned path. There are no guarantees
+  // that the ringbuf outputs messages of type `Value`, only that they are the
+  // same size. Size is only checked in ConsumeAll.
+  static base::Result<std::unique_ptr<BpfRingbuf<Value>>> Create(
+      const char* path);
+
+  // Consumes all messages from the ring buffer, passing them to the callback.
+  // Returns the number of messages consumed or a non-ok result on error. If the
+  // ring buffer has no pending messages an OK result with count 0 is returned.
+  base::Result<int> ConsumeAll(const MessageCallback& callback);
+
+ private:
+  // Empty ctor for use by Create.
+  BpfRingbuf() : BpfRingbufBase(sizeof(Value)) {}
+};
+
+#define ACCESS_ONCE(x) (*(volatile typeof(x)*)&(x))
+
+#if defined(__i386__) || defined(__x86_64__)
+#define smp_sync() asm volatile("" ::: "memory")
+#elif defined(__aarch64__)
+#define smp_sync() asm volatile("dmb ish" ::: "memory")
+#else
+#define smp_sync() __sync_synchronize()
+#endif
+
+#define smp_store_release(p, v) \
+  do {                          \
+    smp_sync();                 \
+    ACCESS_ONCE(*(p)) = (v);    \
+  } while (0)
+
+#define smp_load_acquire(p)        \
+  ({                               \
+    auto ___p = ACCESS_ONCE(*(p)); \
+    smp_sync();                    \
+    ___p;                          \
+  })
+
+inline base::Result<void> BpfRingbufBase::Init(const char* path) {
+  if (sizeof(unsigned long) != 8) {
+    return android::base::Error()
+           << "BpfRingbuf does not support 32 bit architectures";
+  }
+  mRingFd.reset(mapRetrieveRW(path));
+  if (!mRingFd.ok()) {
+    return android::base::ErrnoError()
+           << "failed to retrieve ringbuffer at " << path;
+  }
+
+  int map_type = android::bpf::bpfGetFdMapType(mRingFd);
+  if (map_type != BPF_MAP_TYPE_RINGBUF) {
+    errno = EINVAL;
+    return android::base::ErrnoError()
+           << "bpf map has wrong type: want BPF_MAP_TYPE_RINGBUF ("
+           << BPF_MAP_TYPE_RINGBUF << ") got " << map_type;
+  }
+
+  int max_entries = android::bpf::bpfGetFdMaxEntries(mRingFd);
+  if (max_entries < 0) {
+    return android::base::ErrnoError()
+           << "failed to read max_entries from ringbuf";
+  }
+  if (max_entries == 0) {
+    errno = EINVAL;
+    return android::base::ErrnoError() << "max_entries must be non-zero";
+  }
+  // Sizes are computed in size_t: 2 * max_entries overflows int from 2^30 up.
+  mPosMask = max_entries - 1;
+  mConsumerSize = getpagesize();
+  mProducerSize = getpagesize() + 2 * static_cast<size_t>(max_entries);
+
+  {  // consumer position page: mapped writable so ConsumeAll can advance it
+    void* ptr = mmap(NULL, mConsumerSize, PROT_READ | PROT_WRITE, MAP_SHARED,
+                     mRingFd, 0);
+    if (ptr == MAP_FAILED) {
+      return android::base::ErrnoError()
+             << "failed to mmap ringbuf consumer pages";
+    }
+    mConsumerPos = reinterpret_cast<unsigned long*>(ptr);
+  }
+
+  {  // producer position page plus data: mapped read-only
+    void* ptr = mmap(NULL, mProducerSize, PROT_READ, MAP_SHARED, mRingFd,
+                     mConsumerSize);
+    if (ptr == MAP_FAILED) {
+      return android::base::ErrnoError()
+             << "failed to mmap ringbuf producer page";
+    }
+    mProducerPos = reinterpret_cast<unsigned long*>(ptr);
+  }
+  // Message data begins one page after the start of the producer mapping.
+  mDataPos = pointerAddBytes<void*>(mProducerPos, getpagesize());
+  return {};
+}
+
+inline base::Result<int> BpfRingbufBase::ConsumeAll(
+    const std::function<void(const void*)>& callback) {
+  int64_t count = 0;
+  unsigned long cons_pos = smp_load_acquire(mConsumerPos);
+  unsigned long prod_pos = smp_load_acquire(mProducerPos);
+  while (cons_pos < prod_pos) {
+    // Find the start of the entry for this read (wrapping is done here).
+    void* start_ptr = pointerAddBytes<void*>(mDataPos, cons_pos & mPosMask);
+
+    // The entry has an 8 byte header containing the sample length.
+    uint32_t length = smp_load_acquire(reinterpret_cast<uint32_t*>(start_ptr));
+
+    // If the sample isn't committed, we're caught up with the producer.
+    if (length & BPF_RINGBUF_BUSY_BIT) return count;
+    // Advance past this entry: header plus payload, rounded up to 8 bytes.
+    cons_pos += roundLength(length);
+    // Entries with the discard bit set are skipped without invoking callback.
+    if ((length & BPF_RINGBUF_DISCARD_BIT) == 0) {
+      if (length != mValueSize) {
+        smp_store_release(mConsumerPos, cons_pos);  // skip the bad entry
+        errno = EMSGSIZE;
+        return android::base::ErrnoError()
+               << "BPF ring buffer message has unexpected size (want "
+               << mValueSize << " bytes, got " << length << " bytes)";
+      }
+      callback(pointerAddBytes<const void*>(start_ptr, BPF_RINGBUF_HDR_SZ));
+      count++;
+    }
+
+    smp_store_release(mConsumerPos, cons_pos);
+  }
+
+  return count;
+}
+
+template <typename Value>
+inline base::Result<std::unique_ptr<BpfRingbuf<Value>>>
+BpfRingbuf<Value>::Create(const char* path) {
+  auto rb = std::unique_ptr<BpfRingbuf>(new BpfRingbuf);  // ctor is private
+  if (auto status = rb->Init(path); !status.ok()) return status.error();
+  return rb;
+}
+
+template <typename Value>
+inline base::Result<int> BpfRingbuf<Value>::ConsumeAll(
+    const MessageCallback& callback) {
+  return BpfRingbufBase::ConsumeAll([&](const void* value) {
+    callback(*reinterpret_cast<const Value*>(value));
+  });
+}
+
+#undef ACCESS_ONCE
+#undef smp_sync
+#undef smp_store_release
+#undef smp_load_acquire
+
+}  // namespace bpf
+}  // namespace android
diff --git a/staticlibs/native/bpf_headers/include/bpf/bpf_helpers.h b/staticlibs/native/bpf_headers/include/bpf/bpf_helpers.h
index ed1ee51..36865f3 100644
--- a/staticlibs/native/bpf_headers/include/bpf/bpf_helpers.h
+++ b/staticlibs/native/bpf_headers/include/bpf/bpf_helpers.h
@@ -235,11 +235,12 @@
 
 /* type safe macro to declare a map and related accessor functions */
 #define DEFINE_BPF_MAP_EXT(the_map, TYPE, KeyType, ValueType, num_entries, usr, grp, md,         \
-                           selinux, pindir, share)                                               \
+                           selinux, pindir, share, min_loader, max_loader, ignore_eng,           \
+                           ignore_user, ignore_userdebug)                                        \
   DEFINE_BPF_MAP_BASE(the_map, TYPE, sizeof(KeyType), sizeof(ValueType),                         \
                       num_entries, usr, grp, md, selinux, pindir, share,                         \
-                      KVER_NONE, KVER_INF, BPFLOADER_MIN_VER, BPFLOADER_MAX_VER,                 \
-                      false, false, false);                                                      \
+                      KVER_NONE, KVER_INF, min_loader, max_loader,                               \
+                      ignore_eng, ignore_user, ignore_userdebug);                                \
     BPF_MAP_ASSERT_OK(BPF_MAP_TYPE_##TYPE, (num_entries), (md));                                 \
     BPF_ANNOTATE_KV_PAIR(the_map, KeyType, ValueType);                                           \
                                                                                                  \
@@ -271,9 +272,11 @@
 #error "Bpf Map UID must be left at default of AID_ROOT for BpfLoader prior to v0.28"
 #endif
 
-#define DEFINE_BPF_MAP_UGM(the_map, TYPE, KeyType, ValueType, num_entries, usr, grp, md) \
-    DEFINE_BPF_MAP_EXT(the_map, TYPE, KeyType, ValueType, num_entries, usr, grp, md, \
-                       DEFAULT_BPF_MAP_SELINUX_CONTEXT, DEFAULT_BPF_MAP_PIN_SUBDIR, false)
+#define DEFINE_BPF_MAP_UGM(the_map, TYPE, KeyType, ValueType, num_entries, usr, grp, md)   \
+    DEFINE_BPF_MAP_EXT(the_map, TYPE, KeyType, ValueType, num_entries, usr, grp, md,       \
+                       DEFAULT_BPF_MAP_SELINUX_CONTEXT, DEFAULT_BPF_MAP_PIN_SUBDIR, false, \
+                       BPFLOADER_MIN_VER, BPFLOADER_MAX_VER, /*ignore_on_eng*/false,       \
+                       /*ignore_on_user*/false, /*ignore_on_userdebug*/false)
 
 #define DEFINE_BPF_MAP(the_map, TYPE, KeyType, ValueType, num_entries) \
     DEFINE_BPF_MAP_UGM(the_map, TYPE, KeyType, ValueType, num_entries, \
diff --git a/staticlibs/native/bpf_syscall_wrappers/include/BpfSyscallWrappers.h b/staticlibs/native/bpf_syscall_wrappers/include/BpfSyscallWrappers.h
index f7d6a38..8502961 100644
--- a/staticlibs/native/bpf_syscall_wrappers/include/BpfSyscallWrappers.h
+++ b/staticlibs/native/bpf_syscall_wrappers/include/BpfSyscallWrappers.h
@@ -150,6 +150,18 @@
                                 });
 }
 
+// Issues BPF_PROG_RUN for prog_fd with `data` as input. Available in 4.12 and later kernels.
+inline int runProgram(const BPF_FD_TYPE prog_fd, const void* data,
+                      const uint32_t data_size) {
+    return bpf(BPF_PROG_RUN, {
+                                     .test = {
+                                             .prog_fd = BPF_FD_TO_U32(prog_fd),
+                                             .data_in = ptr_to_u64(data),
+                                             .data_size_in = data_size,
+                                     },
+                             });
+}
+
 // BPF_OBJ_GET_INFO_BY_FD requires 4.14+ kernel
 //
 // Note: some fields are only defined in newer kernels (ie. the map_info struct grows
diff --git a/staticlibs/tests/unit/src/com/android/net/module/util/BitUtilsTests.kt b/staticlibs/tests/unit/src/com/android/net/module/util/BitUtilsTests.kt
index 0236716..49940ea 100644
--- a/staticlibs/tests/unit/src/com/android/net/module/util/BitUtilsTests.kt
+++ b/staticlibs/tests/unit/src/com/android/net/module/util/BitUtilsTests.kt
@@ -17,11 +17,13 @@
 package com.android.net.module.util
 
 import com.android.net.module.util.BitUtils.appendStringRepresentationOfBitMaskToStringBuilder
+import com.android.net.module.util.BitUtils.describeDifferences
 import com.android.net.module.util.BitUtils.packBits
 import com.android.net.module.util.BitUtils.unpackBits
-import org.junit.Test
 import kotlin.test.assertEquals
+import kotlin.test.assertNull
 import kotlin.test.assertTrue
+import org.junit.Test
 
 class BitUtilsTests {
     @Test
@@ -58,4 +60,23 @@
             assertEquals(expected, it.toString())
         }
     }
+
+    @Test
+    fun testDescribeDifferences() {
+        fun describe(a: Long, b: Long) = describeDifferences(a, b, Integer::toString)
+        assertNull(describe(0, 0)) // identical fields yield null
+        assertNull(describe(5, 5))
+        assertNull(describe(Long.MAX_VALUE, Long.MAX_VALUE))
+        // A single changed bit is reported as its sign followed by its position.
+        assertEquals("+0", describe(0, 1))
+        assertEquals("-0", describe(1, 0))
+        // Unchanged bits are omitted; removals are listed before additions.
+        assertEquals("+0+2", describe(0, 5))
+        assertEquals("+2", describe(1, 5))
+        assertEquals("-0+2", describe(1, 4))
+        // Fields built from explicit bit positions, to cover larger mixed diffs.
+        fun makeField(vararg i: Int) = i.sumOf { 1L shl it }
+        assertEquals("-0-4-6-9+1+3+11", describe(makeField(0, 4, 6, 9), makeField(1, 3, 11)))
+        assertEquals("-1-5-9+6+8", describe(makeField(0, 1, 3, 4, 5, 9), makeField(0, 3, 4, 6, 8)))
+    }
 }
diff --git a/staticlibs/tests/unit/src/com/android/testutils/TestableNetworkCallbackTest.kt b/staticlibs/tests/unit/src/com/android/testutils/TestableNetworkCallbackTest.kt
index eed31e0..d36f52a 100644
--- a/staticlibs/tests/unit/src/com/android/testutils/TestableNetworkCallbackTest.kt
+++ b/staticlibs/tests/unit/src/com/android/testutils/TestableNetworkCallbackTest.kt
@@ -140,20 +140,28 @@
         val meteredNc = NetworkCapabilities()
         val unmeteredNc = NetworkCapabilities().addCapability(NOT_METERED)
         // Check that expecting caps (with or without) fails when no callback has been received.
-        assertFails { mCallback.expectCapabilitiesWith(NOT_METERED, matcher, SHORT_TIMEOUT_MS) }
-        assertFails { mCallback.expectCapabilitiesWithout(NOT_METERED, matcher, SHORT_TIMEOUT_MS) }
+        assertFails {
+            mCallback.expectCaps(matcher, SHORT_TIMEOUT_MS) { it.hasCapability(NOT_METERED) }
+        }
+        assertFails {
+            mCallback.expectCaps(matcher, SHORT_TIMEOUT_MS) { !it.hasCapability(NOT_METERED) }
+        }
 
         // Add NOT_METERED and check that With succeeds and Without fails.
         mCallback.onCapabilitiesChanged(net, unmeteredNc)
-        mCallback.expectCapabilitiesWith(NOT_METERED, matcher)
+        mCallback.expectCaps(matcher) { it.hasCapability(NOT_METERED) }
         mCallback.onCapabilitiesChanged(net, unmeteredNc)
-        assertFails { mCallback.expectCapabilitiesWithout(NOT_METERED, matcher, SHORT_TIMEOUT_MS) }
+        assertFails {
+            mCallback.expectCaps(matcher, SHORT_TIMEOUT_MS) { !it.hasCapability(NOT_METERED) }
+        }
 
         // Don't add NOT_METERED and check that With fails and Without succeeds.
         mCallback.onCapabilitiesChanged(net, meteredNc)
-        assertFails { mCallback.expectCapabilitiesWith(NOT_METERED, matcher, SHORT_TIMEOUT_MS) }
+        assertFails {
+            mCallback.expectCaps(matcher, SHORT_TIMEOUT_MS) { it.hasCapability(NOT_METERED) }
+        }
         mCallback.onCapabilitiesChanged(net, meteredNc)
-        mCallback.expectCapabilitiesWithout(NOT_METERED, matcher)
+        mCallback.expectCaps(matcher) { !it.hasCapability(NOT_METERED) }
     }
 
     @Test
@@ -179,33 +187,31 @@
     }
 
     @Test
-    fun testCapabilitiesThat() {
+    fun testExpectCaps() {
         val net = Network(101)
         val netCaps = NetworkCapabilities().addCapability(NOT_METERED).addTransportType(WIFI)
         // Check that expecting capabilitiesThat anything fails when no callback has been received.
-        assertFails { mCallback.expectCapabilitiesThat(net, SHORT_TIMEOUT_MS) { true } }
+        assertFails { mCallback.expectCaps(net, SHORT_TIMEOUT_MS) { true } }
 
         // Basic test for true and false
         mCallback.onCapabilitiesChanged(net, netCaps)
-        mCallback.expectCapabilitiesThat(net) { true }
+        mCallback.expectCaps(net) { true }
         mCallback.onCapabilitiesChanged(net, netCaps)
-        assertFails { mCallback.expectCapabilitiesThat(net, SHORT_TIMEOUT_MS) { false } }
+        assertFails { mCallback.expectCaps(net, SHORT_TIMEOUT_MS) { false } }
 
         // Try a positive and a negative case
         mCallback.onCapabilitiesChanged(net, netCaps)
-        mCallback.expectCapabilitiesThat(net) { caps ->
-            caps.hasCapability(NOT_METERED) &&
-                    caps.hasTransport(WIFI) &&
-                    !caps.hasTransport(CELLULAR)
+        mCallback.expectCaps(net) {
+            it.hasCapability(NOT_METERED) && it.hasTransport(WIFI) && !it.hasTransport(CELLULAR)
         }
         mCallback.onCapabilitiesChanged(net, netCaps)
-        assertFails { mCallback.expectCapabilitiesThat(net, SHORT_TIMEOUT_MS) { caps ->
-            caps.hasTransport(CELLULAR)
-        } }
+        assertFails { mCallback.expectCaps(net, SHORT_TIMEOUT_MS) { it.hasTransport(CELLULAR) } }
 
         // Try a matching callback on the wrong network
         mCallback.onCapabilitiesChanged(net, netCaps)
-        assertFails { mCallback.expectCapabilitiesThat(Network(100), SHORT_TIMEOUT_MS) { true } }
+        assertFails {
+            mCallback.expectCaps(Network(100), SHORT_TIMEOUT_MS) { true }
+        }
     }
 
     @Test
diff --git a/staticlibs/testutils/devicetests/com/android/testutils/TestableNetworkCallback.kt b/staticlibs/testutils/devicetests/com/android/testutils/TestableNetworkCallback.kt
index 68d5fa9..533ec22 100644
--- a/staticlibs/testutils/devicetests/com/android/testutils/TestableNetworkCallback.kt
+++ b/staticlibs/testutils/devicetests/com/android/testutils/TestableNetworkCallback.kt
@@ -414,13 +414,6 @@
         crossinline predicate: (T) -> Boolean = { true }
     ) = history.poll(timeoutMs, from) { it is T && predicate(it) } as T?
 
-    inline fun expectCapabilitiesThat(
-        net: Network,
-        tmt: Long = defaultTimeoutMs,
-        valid: (NetworkCapabilities) -> Boolean
-    ): CapabilitiesChanged =
-            expect(net, tmt, "Capabilities don't match expectations") { valid(it.caps) }
-
     inline fun expectLinkPropertiesThat(
         net: Network,
         tmt: Long = defaultTimeoutMs,
@@ -472,10 +465,8 @@
         if (suspended) {
             expect<Suspended>(net, tmt)
         }
-        expectCapabilitiesThat(net, tmt) {
-            validated == null || validated == it.hasCapability(
-                NET_CAPABILITY_VALIDATED
-            )
+        expect<CapabilitiesChanged>(net, tmt) {
+            validated == null || validated == it.caps.hasCapability(NET_CAPABILITY_VALIDATED)
         }
         expect<LinkPropertiesChanged>(net, tmt)
     }
@@ -514,7 +505,7 @@
     // when a network connects and satisfies a callback, and then immediately validates.
     fun expectAvailableThenValidatedCallbacks(net: Network, tmt: Long = defaultTimeoutMs) {
         expectAvailableCallbacks(net, validated = false, tmt = tmt)
-        expectCapabilitiesThat(net, tmt) { it.hasCapability(NET_CAPABILITY_VALIDATED) }
+        expectCaps(net, tmt) { it.hasCapability(NET_CAPABILITY_VALIDATED) }
     }
 
     fun expectAvailableThenValidatedCallbacks(
@@ -524,7 +515,7 @@
     ) {
         expectAvailableCallbacks(net, validated = false, suspended = false,
                 blockedStatus = blockedStatus, tmt = tmt)
-        expectCapabilitiesThat(net, tmt) { it.hasCapability(NET_CAPABILITY_VALIDATED) }
+        expectCaps(net, tmt) { it.hasCapability(NET_CAPABILITY_VALIDATED) }
     }
 
     // Temporary Java compat measure : have MockNetworkAgent implement this so that all existing
@@ -578,29 +569,28 @@
     ) = expectLinkPropertiesThat(n.network, tmt, valid)
 
     @JvmOverloads
-    fun expectCapabilitiesThat(
+    fun expectCaps(
         n: HasNetwork,
         tmt: Long = defaultTimeoutMs,
+        valid: (NetworkCapabilities) -> Boolean = { true }
+    ) = expect<CapabilitiesChanged>(n.network, tmt) { valid(it.caps) }.caps
+
+    @JvmOverloads
+    fun expectCaps(
+        n: Network,
+        tmt: Long = defaultTimeoutMs,
         valid: (NetworkCapabilities) -> Boolean
-    ) = expectCapabilitiesThat(n.network, tmt, valid)
+    ) = expect<CapabilitiesChanged>(n, tmt) { valid(it.caps) }.caps
 
-    @JvmOverloads
-    fun expectCapabilitiesWith(
-        capability: Int,
+    fun expectCaps(
         n: HasNetwork,
-        timeoutMs: Long = defaultTimeoutMs
-    ): NetworkCapabilities {
-        return expectCapabilitiesThat(n.network, timeoutMs) { it.hasCapability(capability) }.caps
-    }
+        valid: (NetworkCapabilities) -> Boolean
+    ) = expect<CapabilitiesChanged>(n.network) { valid(it.caps) }.caps
 
-    @JvmOverloads
-    fun expectCapabilitiesWithout(
-        capability: Int,
-        n: HasNetwork,
-        timeoutMs: Long = defaultTimeoutMs
-    ): NetworkCapabilities {
-        return expectCapabilitiesThat(n.network, timeoutMs) { !it.hasCapability(capability) }.caps
-    }
+    fun expectCaps(
+        tmt: Long,
+        valid: (NetworkCapabilities) -> Boolean
+    ) = expect<CapabilitiesChanged>(ANY_NETWORK, tmt) { valid(it.caps) }.caps
 
     fun expectBlockedStatusCallback(expectBlocked: Boolean, n: HasNetwork) {
         expectBlockedStatusCallback(expectBlocked, n.network, defaultTimeoutMs)
diff --git a/staticlibs/testutils/devicetests/com/android/testutils/filters/CtsNetTestCasesMaxTargetSdk33.kt b/staticlibs/testutils/devicetests/com/android/testutils/filters/CtsNetTestCasesMaxTargetSdk33.kt
new file mode 100644
index 0000000..5af890f
--- /dev/null
+++ b/staticlibs/testutils/devicetests/com/android/testutils/filters/CtsNetTestCasesMaxTargetSdk33.kt
@@ -0,0 +1,22 @@
+/*
+ * Copyright (C) 2023 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.testutils.filters
+
+/**
+ * Only run this test in the CtsNetTestCasesMaxTargetSdk33 suite.
+ */
+annotation class CtsNetTestCasesMaxTargetSdk33(val reason: String)