Merge "Set config_networkAvoidBadWifi overlay for Verizon"
diff --git a/Cronet/tests/cts/src/android/net/http/cts/HttpEngineTest.java b/Cronet/tests/cts/src/android/net/http/cts/HttpEngineTest.java
index 6a8467c..6d27b43 100644
--- a/Cronet/tests/cts/src/android/net/http/cts/HttpEngineTest.java
+++ b/Cronet/tests/cts/src/android/net/http/cts/HttpEngineTest.java
@@ -17,6 +17,7 @@
 package android.net.http.cts;
 
 import static android.net.http.cts.util.TestUtilsKt.assertOKStatusCode;
+import static android.net.http.cts.util.TestUtilsKt.assumeOKStatusCode;
 import static android.net.http.cts.util.TestUtilsKt.skipIfNoInternetConnection;
 
 import static org.hamcrest.MatcherAssert.assertThat;
@@ -100,13 +101,14 @@
         // We send multiple requests to reduce the flakiness of the test.
         boolean quicWasUsed = false;
         for (int i = 0; i < 5; i++) {
+            mCallback = new TestUrlRequestCallback();
             UrlRequest.Builder builder =
                     mEngine.newUrlRequestBuilder(URL, mCallback, mCallback.getExecutor());
             builder.build().start();
 
             mCallback.expectCallback(ResponseStep.ON_SUCCEEDED);
             UrlResponseInfo info = mCallback.mResponseInfo;
-            assertOKStatusCode(info);
+            assumeOKStatusCode(info);
             quicWasUsed = isQuic(info.getNegotiatedProtocol());
             if (quicWasUsed) {
                 break;
diff --git a/Cronet/tests/cts/src/android/net/http/cts/UrlRequestTest.java b/Cronet/tests/cts/src/android/net/http/cts/UrlRequestTest.java
index d7d3679..5256bae 100644
--- a/Cronet/tests/cts/src/android/net/http/cts/UrlRequestTest.java
+++ b/Cronet/tests/cts/src/android/net/http/cts/UrlRequestTest.java
@@ -21,6 +21,8 @@
 
 import static org.hamcrest.MatcherAssert.assertThat;
 import static org.hamcrest.Matchers.greaterThan;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertSame;
 
 import android.content.Context;
 import android.net.http.HttpEngine;
@@ -29,6 +31,7 @@
 import android.net.http.UrlResponseInfo;
 import android.net.http.cts.util.HttpCtsTestServer;
 import android.net.http.cts.util.TestStatusListener;
+import android.net.http.cts.util.TestUploadDataProvider;
 import android.net.http.cts.util.TestUrlRequestCallback;
 import android.net.http.cts.util.TestUrlRequestCallback.ResponseStep;
 
@@ -66,14 +69,14 @@
         }
     }
 
-    private UrlRequest buildUrlRequest(String url) {
-        return mHttpEngine.newUrlRequestBuilder(url, mCallback, mCallback.getExecutor()).build();
+    private UrlRequest.Builder createUrlRequestBuilder(String url) {
+        return mHttpEngine.newUrlRequestBuilder(url, mCallback, mCallback.getExecutor());
     }
 
     @Test
     public void testUrlRequestGet_CompletesSuccessfully() throws Exception {
         String url = mTestServer.getSuccessUrl();
-        UrlRequest request = buildUrlRequest(url);
+        UrlRequest request = createUrlRequestBuilder(url).build();
         request.start();
 
         mCallback.expectCallback(ResponseStep.ON_SUCCEEDED);
@@ -84,11 +87,48 @@
 
     @Test
     public void testUrlRequestStatus_InvalidBeforeRequestStarts() throws Exception {
-        UrlRequest request = buildUrlRequest(mTestServer.getSuccessUrl());
+        UrlRequest request = createUrlRequestBuilder(mTestServer.getSuccessUrl()).build();
         // Calling before request is started should give Status.INVALID,
         // since the native adapter is not created.
         TestStatusListener statusListener = new TestStatusListener();
         request.getStatus(statusListener);
         statusListener.expectStatus(Status.INVALID);
     }
+
+    @Test
+    public void testUrlRequestCancel_CancelCalled() throws Exception {
+        UrlRequest request = createUrlRequestBuilder(mTestServer.getSuccessUrl()).build();
+        mCallback.setAutoAdvance(false);
+        // No auto-advance: the request should pause at ON_RESPONSE_STARTED, mid-flight.
+        request.start();
+        mCallback.waitForNextStep();
+        assertSame(ResponseStep.ON_RESPONSE_STARTED, mCallback.mResponseStep);
+
+        request.cancel();
+        mCallback.expectCallback(ResponseStep.ON_CANCELED);
+    }
+
+    @Test
+    public void testUrlRequestPost_EchoRequestBody() throws Exception {
+        String testData = "test";
+        UrlRequest.Builder builder = createUrlRequestBuilder(mTestServer.getEchoBodyUrl());
+        // Upload testData as the POST body; the echo endpoint sends the request entity back.
+        TestUploadDataProvider dataProvider = new TestUploadDataProvider(
+                TestUploadDataProvider.SuccessCallbackMode.SYNC, mCallback.getExecutor());
+        dataProvider.addRead(testData.getBytes());
+        builder.setUploadDataProvider(dataProvider, mCallback.getExecutor());
+        builder.addHeader("Content-Type", "text/html");
+        builder.build().start();
+        mCallback.expectCallback(ResponseStep.ON_SUCCEEDED);
+        // The echoed body must match the upload, and the provider must have been closed.
+        assertOKStatusCode(mCallback.mResponseInfo);
+        assertEquals(testData, mCallback.mResponseAsString);
+        dataProvider.assertClosed();
+    }
+
+    @Test
+    public void testUrlRequestFail_FailedCalled() throws Exception {
+        createUrlRequestBuilder("http://0.0.0.0:0/").build().start();
+        mCallback.expectCallback(ResponseStep.ON_FAILED);  // unreachable address: must fail
+    }
 }
diff --git a/Cronet/tests/cts/src/android/net/http/cts/util/HttpCtsTestServer.kt b/Cronet/tests/cts/src/android/net/http/cts/util/HttpCtsTestServer.kt
index 87d5108..5196544 100644
--- a/Cronet/tests/cts/src/android/net/http/cts/util/HttpCtsTestServer.kt
+++ b/Cronet/tests/cts/src/android/net/http/cts/util/HttpCtsTestServer.kt
@@ -18,9 +18,40 @@
 
 import android.content.Context
 import android.webkit.cts.CtsTestServer
+import java.net.URI
+import org.apache.http.HttpEntityEnclosingRequest
+import org.apache.http.HttpRequest
+import org.apache.http.HttpResponse
+import org.apache.http.HttpStatus
+import org.apache.http.HttpVersion
+import org.apache.http.message.BasicHttpResponse
+
+private const val ECHO_BODY_PATH = "/echo_body"
 
 /** Extends CtsTestServer to handle POST requests and other test specific requests */
 class HttpCtsTestServer(context: Context) : CtsTestServer(context) {
 
+    val echoBodyUrl: String = baseUri + ECHO_BODY_PATH
     val successUrl: String = getAssetUrl("html/hello_world.html")
+
+    override fun onPost(req: HttpRequest): HttpResponse? {
+        val path = URI.create(req.requestLine.uri).path
+        // Only the echo endpoint is handled here; other POST paths get a null response.
+        if (!path.startsWith(ECHO_BODY_PATH)) return null
+
+        if (req !is HttpEntityEnclosingRequest) {
+            return BasicHttpResponse(
+                HttpVersion.HTTP_1_0,
+                HttpStatus.SC_INTERNAL_SERVER_ERROR,
+                "Expected req to be of type HttpEntityEnclosingRequest but got ${req.javaClass}"
+            )
+        }
+
+        val response = BasicHttpResponse(HttpVersion.HTTP_1_0, HttpStatus.SC_OK, null)
+        response.entity = req.entity
+        // Negative contentLength means "unknown" (e.g. chunked); never send Content-Length: -1.
+        val length = req.entity?.contentLength ?: 0L
+        if (length >= 0L) response.addHeader("Content-Length", length.toString())
+        return response
+    }
 }
diff --git a/Cronet/tests/cts/src/android/net/http/cts/util/TestUtils.kt b/Cronet/tests/cts/src/android/net/http/cts/util/TestUtils.kt
index d30c059..23ec2c8 100644
--- a/Cronet/tests/cts/src/android/net/http/cts/util/TestUtils.kt
+++ b/Cronet/tests/cts/src/android/net/http/cts/util/TestUtils.kt
@@ -19,8 +19,10 @@
 import android.content.Context
 import android.net.ConnectivityManager
 import android.net.http.UrlResponseInfo
+import org.hamcrest.Matchers.equalTo
 import org.junit.Assert.assertEquals
 import org.junit.Assume.assumeNotNull
+import org.junit.Assume.assumeThat
 
 fun skipIfNoInternetConnection(context: Context) {
     val connectivityManager = context.getSystemService(ConnectivityManager::class.java)
@@ -29,5 +31,9 @@
 }
 
 fun assertOKStatusCode(info: UrlResponseInfo) {
-    assertEquals("Status code must be 200 OK", 200, info.getHttpStatusCode())
+    assertEquals("Status code must be 200 OK", 200, info.httpStatusCode)
+}
+
+fun assumeOKStatusCode(info: UrlResponseInfo) {
+    assumeThat("Status code must be 200 OK", info.httpStatusCode, equalTo(200))
 }
diff --git a/Tethering/src/com/android/networkstack/tethering/BpfCoordinator.java b/Tethering/src/com/android/networkstack/tethering/BpfCoordinator.java
index 44d3ffc..9f8d9b1 100644
--- a/Tethering/src/com/android/networkstack/tethering/BpfCoordinator.java
+++ b/Tethering/src/com/android/networkstack/tethering/BpfCoordinator.java
@@ -2243,5 +2243,13 @@
         return mTetherClients;
     }
 
+    // Return the live map (not a copy) of upstream interface IPv4 address to interface index,
+    // so callers observe later updates. This is used for testing only.
+    @NonNull
+    @VisibleForTesting
+    final HashMap<Inet4Address, Integer> getIpv4UpstreamIndicesForTesting() {
+        return mIpv4UpstreamIndices;
+    }
+
     private static native String[] getBpfCounterNames();
 }
diff --git a/Tethering/src/com/android/networkstack/tethering/Tethering.java b/Tethering/src/com/android/networkstack/tethering/Tethering.java
index e48019c..2e71fda 100644
--- a/Tethering/src/com/android/networkstack/tethering/Tethering.java
+++ b/Tethering/src/com/android/networkstack/tethering/Tethering.java
@@ -88,7 +88,6 @@
 import android.net.LinkAddress;
 import android.net.LinkProperties;
 import android.net.Network;
-import android.net.NetworkCapabilities;
 import android.net.NetworkInfo;
 import android.net.TetherStatesParcel;
 import android.net.TetheredClient;
@@ -1856,8 +1855,11 @@
             final Network newUpstream = (ns != null) ? ns.network : null;
             if (mTetherUpstream != newUpstream) {
                 mTetherUpstream = newUpstream;
-                mUpstreamNetworkMonitor.setCurrentUpstream(mTetherUpstream);
-                reportUpstreamChanged(ns);
+                reportUpstreamChanged(mTetherUpstream);
+                // Need to notify capabilities change after upstream network changed because new
+                // network's capabilities should be checked every time.
+                mNotificationUpdater.onUpstreamCapabilitiesChanged(
+                        (ns != null) ? ns.networkCapabilities : null);
             }
         }
 
@@ -2085,6 +2087,7 @@
                 if (mTetherUpstream != null) {
                     mTetherUpstream = null;
                     reportUpstreamChanged(null);
+                    mNotificationUpdater.onUpstreamCapabilitiesChanged(null);
                 }
                 mBpfCoordinator.stopPolling();
             }
@@ -2439,10 +2442,8 @@
         }
     }
 
-    private void reportUpstreamChanged(UpstreamNetworkState ns) {
+    private void reportUpstreamChanged(final Network network) {
         final int length = mTetheringEventCallbacks.beginBroadcast();
-        final Network network = (ns != null) ? ns.network : null;
-        final NetworkCapabilities capabilities = (ns != null) ? ns.networkCapabilities : null;
         try {
             for (int i = 0; i < length; i++) {
                 try {
@@ -2454,9 +2455,6 @@
         } finally {
             mTetheringEventCallbacks.finishBroadcast();
         }
-        // Need to notify capabilities change after upstream network changed because new network's
-        // capabilities should be checked every time.
-        mNotificationUpdater.onUpstreamCapabilitiesChanged(capabilities);
     }
 
     private void reportConfigurationChanged(TetheringConfigurationParcel config) {
diff --git a/Tethering/src/com/android/networkstack/tethering/UpstreamNetworkMonitor.java b/Tethering/src/com/android/networkstack/tethering/UpstreamNetworkMonitor.java
index 16c031b..ac2aa7b 100644
--- a/Tethering/src/com/android/networkstack/tethering/UpstreamNetworkMonitor.java
+++ b/Tethering/src/com/android/networkstack/tethering/UpstreamNetworkMonitor.java
@@ -133,8 +133,6 @@
     private boolean mIsDefaultCellularUpstream;
     // The current system default network (not really used yet).
     private Network mDefaultInternetNetwork;
-    // The current upstream network used for tethering.
-    private Network mTetheringUpstreamNetwork;
     private boolean mPreferTestNetworks;
 
     public UpstreamNetworkMonitor(Context ctx, StateMachine tgt, SharedLog log, int what) {
@@ -191,7 +189,6 @@
         releaseCallback(mListenAllCallback);
         mListenAllCallback = null;
 
-        mTetheringUpstreamNetwork = null;
         mNetworkMap.clear();
     }
 
@@ -342,11 +339,6 @@
         return findFirstDunNetwork(mNetworkMap.values());
     }
 
-    /** Tell UpstreamNetworkMonitor which network is the current upstream of tethering. */
-    public void setCurrentUpstream(Network upstream) {
-        mTetheringUpstreamNetwork = upstream;
-    }
-
     /** Return local prefixes. */
     public Set<IpPrefix> getLocalPrefixes() {
         return (Set<IpPrefix>) mLocalPrefixes.clone();
diff --git a/Tethering/tests/unit/src/com/android/networkstack/tethering/BpfCoordinatorTest.java b/Tethering/tests/unit/src/com/android/networkstack/tethering/BpfCoordinatorTest.java
index 1978e99..4f32f3c 100644
--- a/Tethering/tests/unit/src/com/android/networkstack/tethering/BpfCoordinatorTest.java
+++ b/Tethering/tests/unit/src/com/android/networkstack/tethering/BpfCoordinatorTest.java
@@ -16,6 +16,8 @@
 
 package com.android.networkstack.tethering;
 
+import static android.net.NetworkCapabilities.TRANSPORT_CELLULAR;
+import static android.net.NetworkCapabilities.TRANSPORT_WIFI;
 import static android.net.NetworkStats.DEFAULT_NETWORK_NO;
 import static android.net.NetworkStats.METERED_NO;
 import static android.net.NetworkStats.ROAMING_NO;
@@ -77,6 +79,7 @@
 import android.app.usage.NetworkStatsManager;
 import android.net.INetd;
 import android.net.InetAddresses;
+import android.net.IpPrefix;
 import android.net.LinkAddress;
 import android.net.LinkProperties;
 import android.net.MacAddress;
@@ -89,6 +92,7 @@
 import android.os.Build;
 import android.os.Handler;
 import android.os.test.TestLooper;
+import android.util.SparseArray;
 
 import androidx.annotation.NonNull;
 import androidx.annotation.Nullable;
@@ -156,11 +160,13 @@
 
     private static final int INVALID_IFINDEX = 0;
     private static final int UPSTREAM_IFINDEX = 1001;
+    private static final int UPSTREAM_XLAT_IFINDEX = 1002;
     private static final int UPSTREAM_IFINDEX2 = 1003;
     private static final int DOWNSTREAM_IFINDEX = 2001;
     private static final int DOWNSTREAM_IFINDEX2 = 2002;
 
     private static final String UPSTREAM_IFACE = "rmnet0";
+    private static final String UPSTREAM_XLAT_IFACE = "v4-rmnet0";
     private static final String UPSTREAM_IFACE2 = "wlan0";
 
     private static final MacAddress DOWNSTREAM_MAC = MacAddress.fromString("12:34:56:78:90:ab");
@@ -183,6 +189,10 @@
     private static final Inet4Address PRIVATE_ADDR2 =
             (Inet4Address) InetAddresses.parseNumericAddress("192.168.90.12");
 
+    private static final Inet4Address XLAT_LOCAL_IPV4ADDR =
+            (Inet4Address) InetAddresses.parseNumericAddress("192.0.0.46");
+    private static final IpPrefix NAT64_IP_PREFIX = new IpPrefix("64:ff9b::/96");
+
     // Generally, public port and private port are the same in the NAT conntrack message.
     // TODO: consider using different private port and public port for testing.
     private static final short REMOTE_PORT = (short) 443;
@@ -194,6 +204,10 @@
     private static final InterfaceParams UPSTREAM_IFACE_PARAMS = new InterfaceParams(
             UPSTREAM_IFACE, UPSTREAM_IFINDEX, null /* macAddr, rawip */,
             NetworkStackConstants.ETHER_MTU);
+    private static final InterfaceParams UPSTREAM_XLAT_IFACE_PARAMS = new InterfaceParams(
+            UPSTREAM_XLAT_IFACE, UPSTREAM_XLAT_IFINDEX, null /* macAddr, rawip */,
+            NetworkStackConstants.ETHER_MTU - 28
+            /* mtu delta from external/android-clat/clatd.c */);
     private static final InterfaceParams UPSTREAM_IFACE_PARAMS2 = new InterfaceParams(
             UPSTREAM_IFACE2, UPSTREAM_IFINDEX2, MacAddress.fromString("44:55:66:00:00:0c"),
             NetworkStackConstants.ETHER_MTU);
@@ -2281,4 +2295,170 @@
         verifyAddTetherOffloadRule4Mtu(INVALID_MTU, false /* isKernelMtu */,
                 NetworkStackConstants.ETHER_MTU /* expectedMtu */);
     }
+
+    private static LinkProperties buildUpstreamLinkProperties(final String interfaceName,
+            boolean withIPv4, boolean withIPv6, boolean with464xlat) {
+        final LinkProperties prop = new LinkProperties();
+        prop.setInterfaceName(interfaceName);
+        // Optionally add an IPv4 address, an IPv6 address and/or a 464xlat stacked link.
+        if (withIPv4) {
+            // Assign the same address regardless of the interface. This is fine for now
+            // because only a single upstream is available.
+            // TODO: consider assigning addresses per interface once we need to test two or more
+            // BPF-supported upstreams, or once multiple upstreams are supported.
+            prop.addLinkAddress(new LinkAddress(PUBLIC_ADDR, 24));
+        }
+
+        if (withIPv6) {
+            // TODO: make this a constant. Currently, no test that uses this function cares what
+            // the upstream IPv6 address is.
+            prop.addLinkAddress(new LinkAddress("2001:db8::5175:15ca/64"));
+        }
+
+        if (with464xlat) {
+            final String clatInterface = "v4-" + interfaceName;  // e.g. "v4-rmnet0"
+            final LinkProperties stackedLink = new LinkProperties();
+            stackedLink.setInterfaceName(clatInterface);
+            stackedLink.addLinkAddress(new LinkAddress(XLAT_LOCAL_IPV4ADDR, 24));
+            prop.addStackedLink(stackedLink);
+            prop.setNat64Prefix(NAT64_IP_PREFIX);
+        }
+
+        return prop;
+    }
+
+    private void verifyIpv4Upstream(
+            @NonNull final HashMap<Inet4Address, Integer> ipv4UpstreamIndices,
+            @NonNull final SparseArray<String> interfaceNames) {
+        assertEquals(1, ipv4UpstreamIndices.size());  // only PUBLIC_ADDR -> UPSTREAM_IFINDEX
+        Integer upstreamIndex = ipv4UpstreamIndices.get(PUBLIC_ADDR);
+        assertNotNull(upstreamIndex);
+        assertEquals(UPSTREAM_IFINDEX, upstreamIndex.intValue());
+        assertEquals(1, interfaceNames.size());
+        assertTrue(interfaceNames.contains(UPSTREAM_IFINDEX));
+    }
+
+    private void verifyUpdateUpstreamNetworkState()
+            throws Exception {
+        final BpfCoordinator coordinator = makeBpfCoordinator();
+        final HashMap<Inet4Address, Integer> ipv4UpstreamIndices =
+                coordinator.getIpv4UpstreamIndicesForTesting();
+        assertTrue(ipv4UpstreamIndices.isEmpty());
+        final SparseArray<String> interfaceNames =
+                coordinator.getInterfaceNamesForTesting();
+        assertEquals(0, interfaceNames.size());
+
+        // Verify that the following are added or removed after upstream changes.
+        // - BpfCoordinator#mIpv4UpstreamIndices (for building IPv4 offload rules)
+        // - BpfCoordinator#mInterfaceNames (for updating limit)
+        //
+        // +-------+-------+-----------------------+
+        // | Test  | Up    |       Protocol        |
+        // | Case# | stream+-------+-------+-------+
+        // |       |       | IPv4  | IPv6  | Xlat  |
+        // +-------+-------+-------+-------+-------+
+        // |   1   | Cell  |   O   |       |       |
+        // +-------+-------+-------+-------+-------+
+        // |   2   | Cell  |       |   O   |       |
+        // +-------+-------+-------+-------+-------+
+        // |   3   | Cell  |   O   |   O   |       |
+        // +-------+-------+-------+-------+-------+
+        // |   4   |   -   |       |       |       |
+        // +-------+-------+-------+-------+-------+
+        // |       | Cell  |   O   |       |       |
+        // |       +-------+-------+-------+-------+
+        // |   5   | Cell  |       |   O   |   O   | <-- doesn't support offload (xlat)
+        // |       +-------+-------+-------+-------+
+        // |       | Cell  |   O   |       |       |
+        // +-------+-------+-------+-------+-------+
+        // |   6   | Wifi  |   O   |   O   |       | <-- doesn't support offload (ether ip)
+        // +-------+-------+-------+-------+-------+
+
+        // [1] Mobile IPv4 only
+        coordinator.addUpstreamNameToLookupTable(UPSTREAM_IFINDEX, UPSTREAM_IFACE);
+        doReturn(UPSTREAM_IFACE_PARAMS).when(mDeps).getInterfaceParams(UPSTREAM_IFACE);
+        final UpstreamNetworkState mobileIPv4UpstreamState = new UpstreamNetworkState(
+                buildUpstreamLinkProperties(UPSTREAM_IFACE,
+                        true /* IPv4 */, false /* IPv6 */, false /* 464xlat */),
+                new NetworkCapabilities().addTransportType(TRANSPORT_CELLULAR),
+                new Network(TEST_NET_ID));
+        coordinator.updateUpstreamNetworkState(mobileIPv4UpstreamState);
+        verifyIpv4Upstream(ipv4UpstreamIndices, interfaceNames);
+
+        // [2] Mobile IPv6 only
+        final UpstreamNetworkState mobileIPv6UpstreamState = new UpstreamNetworkState(
+                buildUpstreamLinkProperties(UPSTREAM_IFACE,
+                        false /* IPv4 */, true /* IPv6 */, false /* 464xlat */),
+                new NetworkCapabilities().addTransportType(TRANSPORT_CELLULAR),
+                new Network(TEST_NET_ID));
+        coordinator.updateUpstreamNetworkState(mobileIPv6UpstreamState);
+        assertTrue(ipv4UpstreamIndices.isEmpty());
+        assertEquals(1, interfaceNames.size());
+        assertTrue(interfaceNames.contains(UPSTREAM_IFINDEX));
+
+        // [3] Mobile IPv4 and IPv6
+        final UpstreamNetworkState mobileDualStackUpstreamState = new UpstreamNetworkState(
+                buildUpstreamLinkProperties(UPSTREAM_IFACE,
+                        true /* IPv4 */, true /* IPv6 */, false /* 464xlat */),
+                new NetworkCapabilities().addTransportType(TRANSPORT_CELLULAR),
+                new Network(TEST_NET_ID));
+        coordinator.updateUpstreamNetworkState(mobileDualStackUpstreamState);
+        verifyIpv4Upstream(ipv4UpstreamIndices, interfaceNames);
+
+        // [4] Lost upstream
+        coordinator.updateUpstreamNetworkState(null);
+        assertTrue(ipv4UpstreamIndices.isEmpty());
+        assertEquals(1, interfaceNames.size());
+        assertTrue(interfaceNames.contains(UPSTREAM_IFINDEX));
+
+        // [5] Verify the xlat interface
+        // Expect that xlat interface information isn't added to the mapping.
+        doReturn(UPSTREAM_XLAT_IFACE_PARAMS).when(mDeps).getInterfaceParams(
+                UPSTREAM_XLAT_IFACE);
+        final UpstreamNetworkState mobile464xlatUpstreamState = new UpstreamNetworkState(
+                buildUpstreamLinkProperties(UPSTREAM_IFACE,
+                        false /* IPv4 */, true /* IPv6 */, true /* 464xlat */),
+                new NetworkCapabilities().addTransportType(TRANSPORT_CELLULAR),
+                new Network(TEST_NET_ID));
+
+        // Set a valid IPv4 upstream first, so switching to xlat visibly clears its mapping.
+        // Mobile IPv4 only
+        coordinator.updateUpstreamNetworkState(mobileIPv4UpstreamState);
+        verifyIpv4Upstream(ipv4UpstreamIndices, interfaceNames);
+
+        // Mobile IPv6 and xlat
+        // IpServer doesn't add the xlat interface mapping via #addUpstreamNameToLookupTable on
+        // S and T devices.
+        coordinator.updateUpstreamNetworkState(mobile464xlatUpstreamState);
+        // The upstream IPv4 address mapping is removed because xlat interface is not supported.
+        assertTrue(ipv4UpstreamIndices.isEmpty());
+        assertEquals(1, interfaceNames.size());
+        assertTrue(interfaceNames.contains(UPSTREAM_IFINDEX));
+
+        // Set a valid IPv4 upstream first, so switching to wifi visibly clears its mapping.
+        // Mobile IPv4 only
+        coordinator.updateUpstreamNetworkState(mobileIPv4UpstreamState);
+        verifyIpv4Upstream(ipv4UpstreamIndices, interfaceNames);
+
+        // [6] Wifi IPv4 and IPv6
+        // Expect that the upstream index map is cleared because ether ip is not supported.
+        coordinator.addUpstreamNameToLookupTable(UPSTREAM_IFINDEX2, UPSTREAM_IFACE2);
+        doReturn(UPSTREAM_IFACE_PARAMS2).when(mDeps).getInterfaceParams(UPSTREAM_IFACE2);
+        final UpstreamNetworkState wifiDualStackUpstreamState = new UpstreamNetworkState(
+                buildUpstreamLinkProperties(UPSTREAM_IFACE2,
+                        true /* IPv4 */, true /* IPv6 */, false /* 464xlat */),
+                new NetworkCapabilities().addTransportType(TRANSPORT_WIFI),
+                new Network(TEST_NET_ID2));
+        coordinator.updateUpstreamNetworkState(wifiDualStackUpstreamState);
+        assertTrue(ipv4UpstreamIndices.isEmpty());
+        assertEquals(2, interfaceNames.size());
+        assertTrue(interfaceNames.contains(UPSTREAM_IFINDEX));
+        assertTrue(interfaceNames.contains(UPSTREAM_IFINDEX2));
+    }
+
+    @Test
+    @IgnoreUpTo(Build.VERSION_CODES.R)  // NOTE(review): presumably needs S+ BPF offload; confirm
+    public void testUpdateUpstreamNetworkState() throws Exception {
+        verifyUpdateUpstreamNetworkState();
+    }
 }
diff --git a/Tethering/tests/unit/src/com/android/networkstack/tethering/TetheringTest.java b/Tethering/tests/unit/src/com/android/networkstack/tethering/TetheringTest.java
index f90b3a4..98a3b1d 100644
--- a/Tethering/tests/unit/src/com/android/networkstack/tethering/TetheringTest.java
+++ b/Tethering/tests/unit/src/com/android/networkstack/tethering/TetheringTest.java
@@ -261,6 +261,8 @@
 
     private static final int DHCPSERVER_START_TIMEOUT_MS = 1000;
 
+    private static final Network[] NULL_NETWORK = new Network[] {null};
+
     @Mock private ApplicationInfo mApplicationInfo;
     @Mock private Context mContext;
     @Mock private NetworkStatsManager mStatsManager;
@@ -303,6 +305,7 @@
     private MockContentResolver mContentResolver;
     private BroadcastReceiver mBroadcastReceiver;
     private Tethering mTethering;
+    private TestTetheringEventCallback mTetheringEventCallback;
     private PhoneStateListener mPhoneStateListener;
     private InterfaceConfigurationParcel mInterfaceConfiguration;
     private TetheringConfiguration mConfig;
@@ -670,6 +673,7 @@
         verify(mStatsManager, times(1)).registerNetworkStatsProvider(anyString(), any());
         verify(mNetd).registerUnsolicitedEventListener(any());
         verifyDefaultNetworkRequestFiled();
+        mTetheringEventCallback = registerTetheringEventCallback();
 
         final ArgumentCaptor<PhoneStateListener> phoneListenerCaptor =
                 ArgumentCaptor.forClass(PhoneStateListener.class);
@@ -745,6 +749,16 @@
         return request;
     }
 
+    @NonNull
+    private TestTetheringEventCallback registerTetheringEventCallback() {
+        TestTetheringEventCallback callback = new TestTetheringEventCallback();
+        mTethering.registerTetheringEventCallback(callback);
+        mLooper.dispatchAll();
+        // Consume the initial upstream-changed (null network) event sent upon registration.
+        callback.expectUpstreamChanged(NULL_NETWORK);
+        return callback;
+    }
+
     @After
     public void tearDown() {
         mServiceContext.unregisterReceiver(mBroadcastReceiver);
@@ -924,9 +938,9 @@
         // tetherMatchingInterfaces() which starts by fetching all interfaces).
         verify(mNetd, times(1)).interfaceGetList();
 
-        // UpstreamNetworkMonitor should receive selected upstream
+        // Event callback should receive selected upstream
         verify(mUpstreamNetworkMonitor, times(1)).getCurrentPreferredUpstream();
-        verify(mUpstreamNetworkMonitor, times(1)).setCurrentUpstream(upstreamState.network);
+        mTetheringEventCallback.expectUpstreamChanged(upstreamState.network);
     }
 
     @Test
@@ -1181,7 +1195,7 @@
         verify(mUpstreamNetworkMonitor, times(1)).getCurrentPreferredUpstream();
         verify(mUpstreamNetworkMonitor, never()).selectPreferredUpstreamType(any());
 
-        verify(mUpstreamNetworkMonitor, times(1)).setCurrentUpstream(upstreamState.network);
+        mTetheringEventCallback.expectUpstreamChanged(upstreamState.network);
     }
 
     private void verifyDisableTryCellWhenTetheringStop(InOrder inOrder) {
@@ -1206,14 +1220,14 @@
         mobile.fakeConnect();
         mCm.makeDefaultNetwork(mobile, BROADCAST_FIRST);
         mLooper.dispatchAll();
-        inOrder.verify(mUpstreamNetworkMonitor).setCurrentUpstream(mobile.networkId);
+        mTetheringEventCallback.expectUpstreamChanged(mobile.networkId);
 
         // Switch upstream to wifi.
         wifi.fakeConnect();
         mCm.makeDefaultNetwork(wifi, BROADCAST_FIRST);
         mLooper.dispatchAll();
         inOrder.verify(mUpstreamNetworkMonitor).setTryCell(false);
-        inOrder.verify(mUpstreamNetworkMonitor).setCurrentUpstream(wifi.networkId);
+        mTetheringEventCallback.expectUpstreamChanged(wifi.networkId);
     }
 
     private void verifyAutomaticUpstreamSelection(boolean configAutomatic) throws Exception {
@@ -1230,30 +1244,30 @@
         // Switch upstreams a few times.
         mCm.makeDefaultNetwork(mobile, BROADCAST_FIRST, doDispatchAll);
         mLooper.dispatchAll();
-        inOrder.verify(mUpstreamNetworkMonitor).setCurrentUpstream(mobile.networkId);
+        mTetheringEventCallback.expectUpstreamChanged(mobile.networkId);
 
         mCm.makeDefaultNetwork(wifi, BROADCAST_FIRST, doDispatchAll);
         mLooper.dispatchAll();
         inOrder.verify(mUpstreamNetworkMonitor).setTryCell(false);
-        inOrder.verify(mUpstreamNetworkMonitor).setCurrentUpstream(wifi.networkId);
+        mTetheringEventCallback.expectUpstreamChanged(wifi.networkId);
 
         mCm.makeDefaultNetwork(mobile, CALLBACKS_FIRST);
         mLooper.dispatchAll();
-        inOrder.verify(mUpstreamNetworkMonitor).setCurrentUpstream(mobile.networkId);
+        mTetheringEventCallback.expectUpstreamChanged(mobile.networkId);
 
         mCm.makeDefaultNetwork(wifi, CALLBACKS_FIRST);
         mLooper.dispatchAll();
         inOrder.verify(mUpstreamNetworkMonitor).setTryCell(false);
-        inOrder.verify(mUpstreamNetworkMonitor).setCurrentUpstream(wifi.networkId);
+        mTetheringEventCallback.expectUpstreamChanged(wifi.networkId);
 
         mCm.makeDefaultNetwork(mobile, CALLBACKS_FIRST, doDispatchAll);
         mLooper.dispatchAll();
-        inOrder.verify(mUpstreamNetworkMonitor).setCurrentUpstream(mobile.networkId);
+        mTetheringEventCallback.expectUpstreamChanged(mobile.networkId);
 
         // Wifi disconnecting should not have any affect since it's not the current upstream.
         wifi.fakeDisconnect();
         mLooper.dispatchAll();
-        inOrder.verify(mUpstreamNetworkMonitor, never()).setCurrentUpstream(any());
+        mTetheringEventCallback.assertNoUpstreamChangeCallback();
 
         // Lose and regain upstream.
         assertTrue(mUpstreamNetworkMonitor.getCurrentPreferredUpstream().linkProperties
@@ -1263,13 +1277,13 @@
         mobile.fakeDisconnect();
         mLooper.dispatchAll();
         inOrder.verify(mUpstreamNetworkMonitor).setTryCell(true);
-        inOrder.verify(mUpstreamNetworkMonitor).setCurrentUpstream(null);
+        mTetheringEventCallback.expectUpstreamChanged(NULL_NETWORK);
 
         mobile = new TestNetworkAgent(mCm, buildMobile464xlatUpstreamState());
         mobile.fakeConnect();
         mCm.makeDefaultNetwork(mobile, BROADCAST_FIRST, doDispatchAll);
         mLooper.dispatchAll();
-        inOrder.verify(mUpstreamNetworkMonitor).setCurrentUpstream(mobile.networkId);
+        mTetheringEventCallback.expectUpstreamChanged(mobile.networkId);
 
         // Check the IP addresses to ensure that the upstream is indeed not the same as the previous
         // mobile upstream, even though the netId is (unrealistically) the same.
@@ -1281,13 +1295,13 @@
         mobile.fakeDisconnect();
         mLooper.dispatchAll();
         inOrder.verify(mUpstreamNetworkMonitor).setTryCell(true);
-        inOrder.verify(mUpstreamNetworkMonitor).setCurrentUpstream(null);
+        mTetheringEventCallback.expectUpstreamChanged(NULL_NETWORK);
 
         mobile = new TestNetworkAgent(mCm, buildMobileDualStackUpstreamState());
         mobile.fakeConnect();
         mCm.makeDefaultNetwork(mobile, CALLBACKS_FIRST, doDispatchAll);
         mLooper.dispatchAll();
-        inOrder.verify(mUpstreamNetworkMonitor).setCurrentUpstream(mobile.networkId);
+        mTetheringEventCallback.expectUpstreamChanged(mobile.networkId);
 
         assertTrue(mUpstreamNetworkMonitor.getCurrentPreferredUpstream().linkProperties
                 .hasIPv4Address());
@@ -1328,27 +1342,27 @@
         mLooper.dispatchAll();
         mCm.makeDefaultNetwork(mobile, CALLBACKS_FIRST, null);
         mLooper.dispatchAll();
-        inOrder.verify(mUpstreamNetworkMonitor).setCurrentUpstream(mobile.networkId);
+        mTetheringEventCallback.expectUpstreamChanged(mobile.networkId);
 
         wifi.fakeConnect();
         mLooper.dispatchAll();
         mCm.makeDefaultNetwork(wifi, CALLBACKS_FIRST, null);
         mLooper.dispatchAll();
         inOrder.verify(mUpstreamNetworkMonitor).setTryCell(false);
-        inOrder.verify(mUpstreamNetworkMonitor).setCurrentUpstream(wifi.networkId);
+        mTetheringEventCallback.expectUpstreamChanged(wifi.networkId);
 
         verifyDisableTryCellWhenTetheringStop(inOrder);
     }
 
     private void verifyWifiUpstreamAndUnregisterDunCallback(@NonNull final InOrder inOrder,
-            @NonNull final TestNetworkAgent wifi,
-            @NonNull final NetworkCallback currentDunCallack) throws Exception {
+            @NonNull final TestNetworkAgent wifi, @NonNull final NetworkCallback currentDunCallack)
+            throws Exception {
         assertNotNull(currentDunCallack);
 
         inOrder.verify(mUpstreamNetworkMonitor).setTryCell(false);
         inOrder.verify(mCm).unregisterNetworkCallback(eq(currentDunCallack));
-        inOrder.verify(mUpstreamNetworkMonitor).setCurrentUpstream(wifi.networkId);
-        inOrder.verify(mUpstreamNetworkMonitor, never()).setCurrentUpstream(any());
+        mTetheringEventCallback.expectUpstreamChanged(wifi.networkId);
+        mTetheringEventCallback.assertNoUpstreamChangeCallback();
     }
 
     @Nullable
@@ -1363,11 +1377,11 @@
                     captor.capture());
             dunNetworkCallback = captor.getValue();
         }
-        inOrder.verify(mUpstreamNetworkMonitor).setCurrentUpstream(null);
+        mTetheringEventCallback.expectUpstreamChanged(NULL_NETWORK);
         final Runnable doDispatchAll = () -> mLooper.dispatchAll();
         dun.fakeConnect(CALLBACKS_FIRST, doDispatchAll);
         mLooper.dispatchAll();
-        inOrder.verify(mUpstreamNetworkMonitor).setCurrentUpstream(dun.networkId);
+        mTetheringEventCallback.expectUpstreamChanged(dun.networkId);
 
         if (needToRequestNetwork) {
             assertNotNull(dunNetworkCallback);
@@ -1484,11 +1498,11 @@
         dun.fakeDisconnect(CALLBACKS_FIRST, doDispatchAll);
         mLooper.dispatchAll();
         inOrder.verify(mUpstreamNetworkMonitor).setTryCell(true);
-        inOrder.verify(mUpstreamNetworkMonitor).setCurrentUpstream(null);
+        mTetheringEventCallback.expectUpstreamChanged(NULL_NETWORK);
         inOrder.verify(mCm, never()).unregisterNetworkCallback(any(NetworkCallback.class));
         dun.fakeConnect(CALLBACKS_FIRST, doDispatchAll);
         mLooper.dispatchAll();
-        inOrder.verify(mUpstreamNetworkMonitor).setCurrentUpstream(dun.networkId);
+        mTetheringEventCallback.expectUpstreamChanged(dun.networkId);
 
         verifyDisableTryCellWhenTetheringStop(inOrder);
     }
@@ -1543,8 +1557,8 @@
         // automatic mode would request dun again and choose it as upstream.
         mCm.makeDefaultNetwork(mobile, CALLBACKS_FIRST);
         mLooper.dispatchAll();
-        final NetworkCallback dunNetworkCallback2 =
-                verifyDunUpstream(inOrder, dun, true /* needToRequestNetwork */);
+        final NetworkCallback dunNetworkCallback2 = verifyDunUpstream(inOrder, dun,
+                true /* needToRequestNetwork */);
 
         // [3] When default network switch to wifi and mobile is still connected,
         // unregister dun request and choose wifi as upstream.
@@ -1556,7 +1570,7 @@
         final Runnable doDispatchAll = () -> mLooper.dispatchAll();
         mobile.fakeDisconnect(CALLBACKS_FIRST, doDispatchAll);
         mLooper.dispatchAll();
-        inOrder.verify(mUpstreamNetworkMonitor, never()).setCurrentUpstream(any());
+        mTetheringEventCallback.assertNoUpstreamChangeCallback();
 
         verifyDisableTryCellWhenTetheringStop(inOrder);
     }
@@ -1627,19 +1641,19 @@
         final Runnable doDispatchAll = () -> mLooper.dispatchAll();
         dun.fakeConnect(CALLBACKS_FIRST, doDispatchAll);
         mLooper.dispatchAll();
-        inOrder.verify(mUpstreamNetworkMonitor).setCurrentUpstream(dun.networkId);
+        mTetheringEventCallback.expectUpstreamChanged(dun.networkId);
 
         // [6] When mobile is connected and default network switch to mobile, keep dun as upstream.
         mobile.fakeConnect();
         mCm.makeDefaultNetwork(mobile, CALLBACKS_FIRST);
         mLooper.dispatchAll();
-        inOrder.verify(mUpstreamNetworkMonitor, never()).setCurrentUpstream(any());
+        mTetheringEventCallback.assertNoUpstreamChangeCallback();
 
         // [7] When mobile is disconnected, keep dun as upstream.
         mCm.makeDefaultNetwork(null, CALLBACKS_FIRST, doDispatchAll);
         mobile.fakeDisconnect(CALLBACKS_FIRST, doDispatchAll);
         mLooper.dispatchAll();
-        inOrder.verify(mUpstreamNetworkMonitor, never()).setCurrentUpstream(any());
+        mTetheringEventCallback.assertNoUpstreamChangeCallback();
 
         verifyDisableTryCellWhenTetheringStop(inOrder);
     }
@@ -1670,7 +1684,7 @@
         final Runnable doDispatchAll = () -> mLooper.dispatchAll();
         dun.fakeConnect(CALLBACKS_FIRST, doDispatchAll);
         mLooper.dispatchAll();
-        inOrder.verify(mUpstreamNetworkMonitor).setCurrentUpstream(dun.networkId);
+        mTetheringEventCallback.expectUpstreamChanged(dun.networkId);
 
         // [8] Lose and regain upstream again.
         dun.fakeDisconnect(CALLBACKS_FIRST, doDispatchAll);
@@ -1714,7 +1728,7 @@
         final Runnable doDispatchAll = () -> mLooper.dispatchAll();
         dun.fakeConnect(CALLBACKS_FIRST, doDispatchAll);
         mLooper.dispatchAll();
-        inOrder.verify(mUpstreamNetworkMonitor).setCurrentUpstream(dun.networkId);
+        mTetheringEventCallback.expectUpstreamChanged(dun.networkId);
 
         // [9] When wifi is connected and default network switch to wifi, unregister dun request
         // and choose wifi as upstream. When dun is disconnected, keep wifi as upstream.
@@ -1724,7 +1738,7 @@
         verifyWifiUpstreamAndUnregisterDunCallback(inOrder, wifi, dunNetworkCallback);
         dun.fakeDisconnect(CALLBACKS_FIRST, doDispatchAll);
         mLooper.dispatchAll();
-        inOrder.verify(mUpstreamNetworkMonitor, never()).setCurrentUpstream(any());
+        mTetheringEventCallback.assertNoUpstreamChangeCallback();
 
         // [10] When mobile and mobile are connected and default network switch to mobile
         // (may have low signal), automatic mode would request dun again and choose it as
@@ -1752,7 +1766,7 @@
         mCm.makeDefaultNetwork(mobile, CALLBACKS_FIRST);
         mLooper.dispatchAll();
         inOrder.verify(mUpstreamNetworkMonitor).setTryCell(false);
-        inOrder.verify(mUpstreamNetworkMonitor, never()).setCurrentUpstream(any());
+        mTetheringEventCallback.assertNoUpstreamChangeCallback();
         // BUG: when wifi disconnect, the dun request would not be filed again because wifi is
         // no longer be default network which do not have CONNECTIVIY_ACTION broadcast.
         wifi.fakeDisconnect();
@@ -1799,20 +1813,21 @@
     }
 
     private void chooseDunUpstreamTestCommon(final boolean automatic, InOrder inOrder,
-            TestNetworkAgent mobile, TestNetworkAgent wifi, TestNetworkAgent dun) throws Exception {
+            TestNetworkAgent mobile, TestNetworkAgent wifi, TestNetworkAgent dun)
+            throws Exception {
         final NetworkCallback dunNetworkCallback = setupDunUpstreamTest(automatic, inOrder);
 
         // Pretend cellular connected and expect the upstream to be not set.
         mobile.fakeConnect();
         mCm.makeDefaultNetwork(mobile, BROADCAST_FIRST);
         mLooper.dispatchAll();
-        inOrder.verify(mUpstreamNetworkMonitor, never()).setCurrentUpstream(mobile.networkId);
+        mTetheringEventCallback.assertNoUpstreamChangeCallback();
 
         // Pretend dun connected and expect choose dun as upstream.
         final Runnable doDispatchAll = () -> mLooper.dispatchAll();
         dun.fakeConnect(BROADCAST_FIRST, doDispatchAll);
         mLooper.dispatchAll();
-        inOrder.verify(mUpstreamNetworkMonitor).setCurrentUpstream(dun.networkId);
+        mTetheringEventCallback.expectUpstreamChanged(dun.networkId);
 
         // When wifi connected, unregister dun request and choose wifi as upstream.
         wifi.fakeConnect();
@@ -1821,7 +1836,7 @@
         verifyWifiUpstreamAndUnregisterDunCallback(inOrder, wifi, dunNetworkCallback);
         dun.fakeDisconnect(BROADCAST_FIRST, doDispatchAll);
         mLooper.dispatchAll();
-        inOrder.verify(mUpstreamNetworkMonitor, never()).setCurrentUpstream(any());
+        mTetheringEventCallback.assertNoUpstreamChangeCallback();
     }
 
     private void runNcmTethering() {
@@ -2268,7 +2283,7 @@
         mTethering.registerTetheringEventCallback(callback);
         mLooper.dispatchAll();
         callback.expectTetheredClientChanged(Collections.emptyList());
-        callback.expectUpstreamChanged(new Network[] {null});
+        callback.expectUpstreamChanged(NULL_NETWORK);
         callback.expectConfigurationChanged(
                 mTethering.getTetheringConfiguration().toStableParcelable());
         TetherStatesParcel tetherState = callback.pollTetherStatesChanged();
@@ -2316,7 +2331,7 @@
         tetherState = callback2.pollTetherStatesChanged();
         assertArrayEquals(tetherState.availableList, new TetheringInterface[] {wifiIface});
         mLooper.dispatchAll();
-        callback2.expectUpstreamChanged(new Network[] {null});
+        callback2.expectUpstreamChanged(NULL_NETWORK);
         callback2.expectOffloadStatusChanged(TETHER_HARDWARE_OFFLOAD_STOPPED);
         callback.assertNoCallback();
     }
@@ -2663,12 +2678,19 @@
     public void testUpstreamNetworkChanged() {
         final Tethering.TetherMainSM stateMachine = (Tethering.TetherMainSM)
                 mTetheringDependencies.mUpstreamNetworkMonitorSM;
+        // Gain upstream.
         final UpstreamNetworkState upstreamState = buildMobileIPv4UpstreamState();
         initTetheringUpstream(upstreamState);
         stateMachine.chooseUpstreamType(true);
+        mTetheringEventCallback.expectUpstreamChanged(upstreamState.network);
+        verify(mNotificationUpdater)
+                .onUpstreamCapabilitiesChanged(upstreamState.networkCapabilities);
 
-        verify(mUpstreamNetworkMonitor, times(1)).setCurrentUpstream(eq(upstreamState.network));
-        verify(mNotificationUpdater, times(1)).onUpstreamCapabilitiesChanged(any());
+        // Lose upstream.
+        initTetheringUpstream(null);
+        stateMachine.chooseUpstreamType(true);
+        mTetheringEventCallback.expectUpstreamChanged(NULL_NETWORK);
+        verify(mNotificationUpdater).onUpstreamCapabilitiesChanged(null);
     }
 
     @Test
@@ -2682,7 +2704,8 @@
         stateMachine.handleUpstreamNetworkMonitorCallback(EVENT_ON_CAPABILITIES, upstreamState);
         // Should have two onUpstreamCapabilitiesChanged().
         // One is called by reportUpstreamChanged(). One is called by EVENT_ON_CAPABILITIES.
-        verify(mNotificationUpdater, times(2)).onUpstreamCapabilitiesChanged(any());
+        verify(mNotificationUpdater, times(2))
+                .onUpstreamCapabilitiesChanged(upstreamState.networkCapabilities);
         reset(mNotificationUpdater);
 
         // Verify that onUpstreamCapabilitiesChanged won't be called if not current upstream network
@@ -2695,6 +2718,27 @@
     }
 
     @Test
+    public void testUpstreamCapabilitiesChanged_startStopTethering() throws Exception {
+        final TestNetworkAgent wifi = new TestNetworkAgent(mCm, buildWifiUpstreamState());
+
+        // Start USB tethering with no current upstream.
+        prepareUsbTethering();
+        sendUsbBroadcast(true, true, TETHER_USB_RNDIS_FUNCTION);
+
+        // Pretend wifi connected and expect the upstream to be set.
+        wifi.fakeConnect();
+        mCm.makeDefaultNetwork(wifi, CALLBACKS_FIRST);
+        mLooper.dispatchAll();
+        verify(mNotificationUpdater).onUpstreamCapabilitiesChanged(
+                wifi.networkCapabilities);
+
+        // Stop tethering.
+        // Expect that TetherModeAliveState#exit sends capabilities change notification to null.
+        runStopUSBTethering();
+        verify(mNotificationUpdater).onUpstreamCapabilitiesChanged(null);
+    }
+
+    @Test
     public void testDumpTetheringLog() throws Exception {
         final FileDescriptor mockFd = mock(FileDescriptor.class);
         final PrintWriter mockPw = mock(PrintWriter.class);
diff --git a/framework-t/api/current.txt b/framework-t/api/current.txt
index 5532853..86745d4 100644
--- a/framework-t/api/current.txt
+++ b/framework-t/api/current.txt
@@ -228,8 +228,8 @@
   }
 
   public static interface NsdManager.ResolveListener {
+    method public default void onResolutionStopped(@NonNull android.net.nsd.NsdServiceInfo);
     method public void onResolveFailed(android.net.nsd.NsdServiceInfo, int);
-    method public default void onResolveStopped(@NonNull android.net.nsd.NsdServiceInfo);
     method public void onServiceResolved(android.net.nsd.NsdServiceInfo);
     method public default void onStopResolutionFailed(@NonNull android.net.nsd.NsdServiceInfo, int);
   }
diff --git a/framework-t/src/android/net/nsd/NsdManager.java b/framework-t/src/android/net/nsd/NsdManager.java
index 122e3a0..e38ae8e 100644
--- a/framework-t/src/android/net/nsd/NsdManager.java
+++ b/framework-t/src/android/net/nsd/NsdManager.java
@@ -767,18 +767,18 @@
          * Called on the internal thread or with an executor passed to
          * {@link NsdManager#resolveService} to report the resolution was stopped.
          *
-         * A stop resolution operation would call either onResolveStopped or onStopResolutionFailed
-         * once based on the result.
+         * A stop resolution operation would call either onResolutionStopped or
+         * onStopResolutionFailed once based on the result.
          */
-        default void onResolveStopped(@NonNull NsdServiceInfo serviceInfo) { }
+        default void onResolutionStopped(@NonNull NsdServiceInfo serviceInfo) { }
 
         /**
          * Called once on the internal thread or with an executor passed to
          * {@link NsdManager#resolveService} to report that stopping resolution failed with an
          * error.
          *
-         * A stop resolution operation would call either onResolveStopped or onStopResolutionFailed
-         * once based on the result.
+         * A stop resolution operation would call either onResolutionStopped or
+         * onStopResolutionFailed once based on the result.
          */
         default void onStopResolutionFailed(@NonNull NsdServiceInfo serviceInfo,
                 @StopOperationFailureCode int errorCode) { }
@@ -929,7 +929,7 @@
                     break;
                 case STOP_RESOLUTION_SUCCEEDED:
                     removeListener(key);
-                    executor.execute(() -> ((ResolveListener) listener).onResolveStopped(
+                    executor.execute(() -> ((ResolveListener) listener).onResolutionStopped(
                             ns));
                     break;
                 case REGISTER_SERVICE_CALLBACK_FAILED:
@@ -1301,7 +1301,7 @@
     /**
      * Stop service resolution initiated with {@link #resolveService}.
      *
-     * A successful stop is notified with a call to {@link ResolveListener#onResolveStopped}.
+     * A successful stop is notified with a call to {@link ResolveListener#onResolutionStopped}.
      *
      * <p> Upon failure to stop service resolution for example if resolution is done or the
      * requester stops resolution repeatedly, the application is notified
diff --git a/framework/src/android/net/ConnectivityManager.java b/framework/src/android/net/ConnectivityManager.java
index 4224da9..17389b4 100644
--- a/framework/src/android/net/ConnectivityManager.java
+++ b/framework/src/android/net/ConnectivityManager.java
@@ -2279,23 +2279,12 @@
         private final ISocketKeepaliveCallback mCallback;
         private final ExecutorService mExecutor;
 
-        private volatile Integer mSlot;
-
         @UnsupportedAppUsage(maxTargetSdk = Build.VERSION_CODES.R, trackingBug = 170729553)
         public void stop() {
             try {
                 mExecutor.execute(() -> {
                     try {
-                        if (mSlot != null) {
-                            // TODO : this is incorrect, because in the presence of auto on/off
-                            // keepalive the slot associated with this keepalive can have
-                            // changed. Also, this actually causes another problem where some other
-                            // app might stop your keepalive if it just knows the network and
-                            // the slot and goes through the trouble of grabbing the aidl object.
-                            // This code should use the callback to identify what keepalive to
-                            // stop instead.
-                            mService.stopKeepalive(mNetwork, mSlot);
-                        }
+                        mService.stopKeepalive(mCallback);
                     } catch (RemoteException e) {
                         Log.e(TAG, "Error stopping packet keepalive: ", e);
                         throw e.rethrowFromSystemServer();
@@ -2313,11 +2302,10 @@
             mExecutor = Executors.newSingleThreadExecutor();
             mCallback = new ISocketKeepaliveCallback.Stub() {
                 @Override
-                public void onStarted(int slot) {
+                public void onStarted() {
                     final long token = Binder.clearCallingIdentity();
                     try {
                         mExecutor.execute(() -> {
-                            mSlot = slot;
                             callback.onStarted();
                         });
                     } finally {
@@ -2330,7 +2318,6 @@
                     final long token = Binder.clearCallingIdentity();
                     try {
                         mExecutor.execute(() -> {
-                            mSlot = null;
                             callback.onStopped();
                         });
                     } finally {
@@ -2344,7 +2331,6 @@
                     final long token = Binder.clearCallingIdentity();
                     try {
                         mExecutor.execute(() -> {
-                            mSlot = null;
                             callback.onError(error);
                         });
                     } finally {
diff --git a/framework/src/android/net/IConnectivityManager.aidl b/framework/src/android/net/IConnectivityManager.aidl
index 7db231e..acbc31e 100644
--- a/framework/src/android/net/IConnectivityManager.aidl
+++ b/framework/src/android/net/IConnectivityManager.aidl
@@ -193,7 +193,7 @@
     void startTcpKeepalive(in Network network, in ParcelFileDescriptor pfd, int intervalSeconds,
             in ISocketKeepaliveCallback cb);
 
-    void stopKeepalive(in Network network, int slot);
+    void stopKeepalive(in ISocketKeepaliveCallback cb);
 
     String getCaptivePortalServerUrl();
 
diff --git a/framework/src/android/net/ISocketKeepaliveCallback.aidl b/framework/src/android/net/ISocketKeepaliveCallback.aidl
index 020fbca..1240e37 100644
--- a/framework/src/android/net/ISocketKeepaliveCallback.aidl
+++ b/framework/src/android/net/ISocketKeepaliveCallback.aidl
@@ -24,7 +24,7 @@
 oneway interface ISocketKeepaliveCallback
 {
     /** The keepalive was successfully started. */
-    void onStarted(int slot);
+    void onStarted();
     /** The keepalive was successfully stopped. */
     void onStopped();
     /** The keepalive was stopped because of an error. */
diff --git a/framework/src/android/net/NattSocketKeepalive.java b/framework/src/android/net/NattSocketKeepalive.java
index 4d45e70..77137f4 100644
--- a/framework/src/android/net/NattSocketKeepalive.java
+++ b/framework/src/android/net/NattSocketKeepalive.java
@@ -91,9 +91,7 @@
     protected void stopImpl() {
         mExecutor.execute(() -> {
             try {
-                if (mSlot != null) {
-                    mService.stopKeepalive(mNetwork, mSlot);
-                }
+                mService.stopKeepalive(mCallback);
             } catch (RemoteException e) {
                 Log.e(TAG, "Error stopping socket keepalive: ", e);
                 throw e.rethrowFromSystemServer();
diff --git a/framework/src/android/net/NetworkAgent.java b/framework/src/android/net/NetworkAgent.java
index 62e4fe1..8fe20de 100644
--- a/framework/src/android/net/NetworkAgent.java
+++ b/framework/src/android/net/NetworkAgent.java
@@ -291,7 +291,9 @@
     /**
      * Requests that the specified keepalive packet be stopped.
      *
-     * arg1 = hardware slot number of the keepalive to stop.
+     * arg1 = unused
+     * arg2 = error code (SUCCESS)
+     * obj = callback to identify the keepalive
      *
      * Also used internally by ConnectivityService / KeepaliveTracker, with different semantics.
      * @hide
@@ -491,7 +493,7 @@
      * TCP sockets are open over a VPN. The system will check periodically for presence of
      * such open sockets, and this message is what triggers the re-evaluation.
      *
-     * obj = AutomaticOnOffKeepaliveObject.
+     * obj = A Binder object associated with the keepalive.
      * @hide
      */
     public static final int CMD_MONITOR_AUTOMATIC_KEEPALIVE = BASE + 30;
diff --git a/framework/src/android/net/SocketKeepalive.java b/framework/src/android/net/SocketKeepalive.java
index 90e5e9b..2911ce7 100644
--- a/framework/src/android/net/SocketKeepalive.java
+++ b/framework/src/android/net/SocketKeepalive.java
@@ -21,7 +21,6 @@
 import android.annotation.IntDef;
 import android.annotation.IntRange;
 import android.annotation.NonNull;
-import android.annotation.Nullable;
 import android.annotation.SystemApi;
 import android.os.Binder;
 import android.os.ParcelFileDescriptor;
@@ -249,9 +248,6 @@
     @NonNull protected final Executor mExecutor;
     /** @hide */
     @NonNull protected final ISocketKeepaliveCallback mCallback;
-    // TODO: remove slot since mCallback could be used to identify which keepalive to stop.
-    /** @hide */
-    @Nullable protected Integer mSlot;
 
     /** @hide */
     public SocketKeepalive(@NonNull IConnectivityManager service, @NonNull Network network,
@@ -263,11 +259,10 @@
         mExecutor = executor;
         mCallback = new ISocketKeepaliveCallback.Stub() {
             @Override
-            public void onStarted(int slot) {
+            public void onStarted() {
                 final long token = Binder.clearCallingIdentity();
                 try {
                     mExecutor.execute(() -> {
-                        mSlot = slot;
                         callback.onStarted();
                     });
                 } finally {
@@ -280,7 +275,6 @@
                 final long token = Binder.clearCallingIdentity();
                 try {
                     executor.execute(() -> {
-                        mSlot = null;
                         callback.onStopped();
                     });
                 } finally {
@@ -293,7 +287,6 @@
                 final long token = Binder.clearCallingIdentity();
                 try {
                     executor.execute(() -> {
-                        mSlot = null;
                         callback.onError(error);
                     });
                 } finally {
@@ -306,7 +299,6 @@
                 final long token = Binder.clearCallingIdentity();
                 try {
                     executor.execute(() -> {
-                        mSlot = null;
                         callback.onDataReceived();
                     });
                 } finally {
diff --git a/framework/src/android/net/TcpSocketKeepalive.java b/framework/src/android/net/TcpSocketKeepalive.java
index 51d805e..cda5830 100644
--- a/framework/src/android/net/TcpSocketKeepalive.java
+++ b/framework/src/android/net/TcpSocketKeepalive.java
@@ -69,9 +69,7 @@
     protected void stopImpl() {
         mExecutor.execute(() -> {
             try {
-                if (mSlot != null) {
-                    mService.stopKeepalive(mNetwork, mSlot);
-                }
+                mService.stopKeepalive(mCallback);
             } catch (RemoteException e) {
                 Log.e(TAG, "Error stopping packet keepalive: ", e);
                 throw e.rethrowFromSystemServer();
diff --git a/service-t/native/libs/libnetworkstats/BpfNetworkStats.cpp b/service-t/native/libs/libnetworkstats/BpfNetworkStats.cpp
index 122c2d4..4fbc5f4 100644
--- a/service-t/native/libs/libnetworkstats/BpfNetworkStats.cpp
+++ b/service-t/native/libs/libnetworkstats/BpfNetworkStats.cpp
@@ -247,10 +247,6 @@
     return parseBpfNetworkStatsDevInternal(lines, ifaceStatsMap, ifaceIndexNameMap);
 }
 
-uint64_t combineUidTag(const uid_t uid, const uint32_t tag) {
-    return (uint64_t)uid << 32 | tag;
-}
-
 void groupNetworkStats(std::vector<stats_line>* lines) {
     if (lines->size() <= 1) return;
     std::sort(lines->begin(), lines->end());
diff --git a/service-t/native/libs/libnetworkstats/NetworkTraceHandler.cpp b/service-t/native/libs/libnetworkstats/NetworkTraceHandler.cpp
index 4c37b8d..aeadb4a 100644
--- a/service-t/native/libs/libnetworkstats/NetworkTraceHandler.cpp
+++ b/service-t/native/libs/libnetworkstats/NetworkTraceHandler.cpp
@@ -50,6 +50,7 @@
 void NetworkTraceHandler::InitPerfettoTracing() {
   perfetto::TracingInitArgs args = {};
   args.backends |= perfetto::kSystemBackend;
+  args.enable_system_consumer = false;
   perfetto::Tracing::Initialize(args);
   NetworkTraceHandler::RegisterDataSource();
 }
diff --git a/service-t/native/libs/libnetworkstats/NetworkTraceHandlerTest.cpp b/service-t/native/libs/libnetworkstats/NetworkTraceHandlerTest.cpp
index 760ae91..560194f 100644
--- a/service-t/native/libs/libnetworkstats/NetworkTraceHandlerTest.cpp
+++ b/service-t/native/libs/libnetworkstats/NetworkTraceHandlerTest.cpp
@@ -87,7 +87,10 @@
  protected:
   void SetUp() {
     if (access(PACKET_TRACE_RINGBUF_PATH, R_OK)) {
-      GTEST_SKIP() << "Network tracing is not enabled/loaded on this build";
+      GTEST_SKIP() << "Network tracing is not enabled/loaded on this build.";
+    }
+    if (sizeof(void*) != 8) {
+      GTEST_SKIP() << "Network tracing requires 64-bit build.";
     }
   }
 };
diff --git a/service-t/src/com/android/server/IpSecService.java b/service-t/src/com/android/server/IpSecService.java
index 9e71eb3..a884840 100644
--- a/service-t/src/com/android/server/IpSecService.java
+++ b/service-t/src/com/android/server/IpSecService.java
@@ -17,6 +17,7 @@
 package com.android.server;
 
 import static android.Manifest.permission.DUMP;
+import static android.Manifest.permission.NETWORK_SETTINGS;
 import static android.net.IpSecManager.FEATURE_IPSEC_TUNNEL_MIGRATION;
 import static android.net.IpSecManager.INVALID_RESOURCE_ID;
 import static android.system.OsConstants.AF_INET;
@@ -65,6 +66,7 @@
 import com.android.internal.annotations.GuardedBy;
 import com.android.internal.annotations.VisibleForTesting;
 import com.android.internal.util.Preconditions;
+import com.android.modules.utils.build.SdkLevel;
 import com.android.net.module.util.BinderUtils;
 import com.android.net.module.util.NetdUtils;
 import com.android.net.module.util.PermissionUtils;
@@ -102,6 +104,7 @@
 
     private static final int NETD_FETCH_TIMEOUT_MS = 5000; // ms
     private static final InetAddress INADDR_ANY;
+    private static final InetAddress IN6ADDR_ANY;
 
     @VisibleForTesting static final int MAX_PORT_BIND_ATTEMPTS = 10;
 
@@ -110,6 +113,8 @@
     static {
         try {
             INADDR_ANY = InetAddress.getByAddress(new byte[] {0, 0, 0, 0});
+            IN6ADDR_ANY = InetAddress.getByAddress(
+                    new byte[] {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0});
         } catch (UnknownHostException e) {
             throw new RuntimeException(e);
         }
@@ -1013,11 +1018,13 @@
     private final class EncapSocketRecord extends OwnedResourceRecord {
         private FileDescriptor mSocket;
         private final int mPort;
+        private final int mFamily;  // TODO: what about IPV6_ADDRFORM?
 
-        EncapSocketRecord(int resourceId, FileDescriptor socket, int port) {
+        EncapSocketRecord(int resourceId, FileDescriptor socket, int port, int family) {
             super(resourceId);
             mSocket = socket;
             mPort = port;
+            mFamily = family;
         }
 
         /** always guarded by IpSecService#this */
@@ -1038,6 +1045,10 @@
             return mSocket;
         }
 
+        public int getFamily() {
+            return mFamily;
+        }
+
         @Override
         protected ResourceTracker getResourceTracker() {
             return getUserRecord().mSocketQuotaTracker;
@@ -1210,15 +1221,16 @@
      * and re-binding, during which the system could *technically* hand that port out to someone
      * else.
      */
-    private int bindToRandomPort(FileDescriptor sockFd) throws IOException {
+    private int bindToRandomPort(FileDescriptor sockFd, int family, InetAddress localAddr)
+            throws IOException {
         for (int i = MAX_PORT_BIND_ATTEMPTS; i > 0; i--) {
             try {
-                FileDescriptor probeSocket = Os.socket(AF_INET, SOCK_DGRAM, IPPROTO_UDP);
-                Os.bind(probeSocket, INADDR_ANY, 0);
+                FileDescriptor probeSocket = Os.socket(family, SOCK_DGRAM, IPPROTO_UDP);
+                Os.bind(probeSocket, localAddr, 0);
                 int port = ((InetSocketAddress) Os.getsockname(probeSocket)).getPort();
                 Os.close(probeSocket);
                 Log.v(TAG, "Binding to port " + port);
-                Os.bind(sockFd, INADDR_ANY, port);
+                Os.bind(sockFd, localAddr, port);
                 return port;
             } catch (ErrnoException e) {
                 // Someone miraculously claimed the port just after we closed probeSocket.
@@ -1260,6 +1272,19 @@
     @Override
     public synchronized IpSecUdpEncapResponse openUdpEncapsulationSocket(int port, IBinder binder)
             throws RemoteException {
+        // Experimental support for IPv6 UDP encap.
+        final int family;
+        final InetAddress localAddr;
+        if (SdkLevel.isAtLeastU() && port >= 65536) {
+            PermissionUtils.enforceNetworkStackPermissionOr(mContext, NETWORK_SETTINGS);
+            port -= 65536;
+            family = AF_INET6;
+            localAddr = IN6ADDR_ANY;
+        } else {
+            family = AF_INET;
+            localAddr = INADDR_ANY;
+        }
+
         if (port != 0 && (port < FREE_PORT_MIN || port > PORT_MAX)) {
             throw new IllegalArgumentException(
                     "Specified port number must be a valid non-reserved UDP port");
@@ -1278,7 +1303,7 @@
 
             FileDescriptor sockFd = null;
             try {
-                sockFd = Os.socket(AF_INET, SOCK_DGRAM, IPPROTO_UDP);
+                sockFd = Os.socket(family, SOCK_DGRAM, IPPROTO_UDP);
                 pFd = ParcelFileDescriptor.dup(sockFd);
             } finally {
                 IoUtils.closeQuietly(sockFd);
@@ -1295,15 +1320,16 @@
             mNetd.ipSecSetEncapSocketOwner(pFd, callingUid);
             if (port != 0) {
                 Log.v(TAG, "Binding to port " + port);
-                Os.bind(pFd.getFileDescriptor(), INADDR_ANY, port);
+                Os.bind(pFd.getFileDescriptor(), localAddr, port);
             } else {
-                port = bindToRandomPort(pFd.getFileDescriptor());
+                port = bindToRandomPort(pFd.getFileDescriptor(), family, localAddr);
             }
 
             userRecord.mEncapSocketRecords.put(
                     resourceId,
                     new RefcountedResource<EncapSocketRecord>(
-                            new EncapSocketRecord(resourceId, pFd.getFileDescriptor(), port),
+                            new EncapSocketRecord(resourceId, pFd.getFileDescriptor(), port,
+                                    family),
                             binder));
             return new IpSecUdpEncapResponse(IpSecManager.Status.OK, resourceId, port,
                     pFd.getFileDescriptor());
@@ -1580,6 +1606,7 @@
      */
     private void checkIpSecConfig(IpSecConfig config) {
         UserRecord userRecord = mUserResourceTracker.getUserRecord(Binder.getCallingUid());
+        EncapSocketRecord encapSocketRecord = null;
 
         switch (config.getEncapType()) {
             case IpSecTransform.ENCAP_NONE:
@@ -1587,7 +1614,7 @@
             case IpSecTransform.ENCAP_ESPINUDP:
             case IpSecTransform.ENCAP_ESPINUDP_NON_IKE:
                 // Retrieve encap socket record; will throw IllegalArgumentException if not found
-                userRecord.mEncapSocketRecords.getResourceOrThrow(
+                encapSocketRecord = userRecord.mEncapSocketRecords.getResourceOrThrow(
                         config.getEncapSocketResourceId());
 
                 int port = config.getEncapRemotePort();
@@ -1641,10 +1668,9 @@
                             + ") have different address families.");
         }
 
-        // Throw an error if UDP Encapsulation is not used in IPv4.
-        if (config.getEncapType() != IpSecTransform.ENCAP_NONE && sourceFamily != AF_INET) {
+        if (encapSocketRecord != null && encapSocketRecord.getFamily() != destinationFamily) {
             throw new IllegalArgumentException(
-                    "UDP Encapsulation is not supported for this address family");
+                    "UDP encapsulation socket and destination address families must match");
         }
 
         switch (config.getMode()) {
diff --git a/service-t/src/com/android/server/NsdService.java b/service-t/src/com/android/server/NsdService.java
index 5dcf860..4ad39e1 100644
--- a/service-t/src/com/android/server/NsdService.java
+++ b/service-t/src/com/android/server/NsdService.java
@@ -52,7 +52,6 @@
 import android.util.Log;
 import android.util.Pair;
 import android.util.SparseArray;
-import android.util.SparseIntArray;
 
 import com.android.internal.annotations.VisibleForTesting;
 import com.android.internal.util.State;
@@ -119,13 +118,15 @@
     private final NsdStateMachine mNsdStateMachine;
     private final MDnsManager mMDnsManager;
     private final MDnsEventCallback mMDnsEventCallback;
-    @Nullable
+    @NonNull
+    private final Dependencies mDeps;
+    @NonNull
     private final MdnsMultinetworkSocketClient mMdnsSocketClient;
-    @Nullable
+    @NonNull
     private final MdnsDiscoveryManager mMdnsDiscoveryManager;
-    @Nullable
+    @NonNull
     private final MdnsSocketProvider mMdnsSocketProvider;
-    @Nullable
+    @NonNull
     private final MdnsAdvertiser mAdvertiser;
     // WARNING : Accessing these values in any thread is not safe, it must only be changed in the
     // state machine thread. If change this outside state machine, it will need to introduce
@@ -312,21 +313,14 @@
             mIsMonitoringSocketsStarted = true;
         }
 
-        private void maybeStopMonitoringSockets() {
-            if (!mIsMonitoringSocketsStarted) {
-                if (DBG) Log.d(TAG, "Socket monitoring has not been started.");
-                return;
-            }
+        private void maybeStopMonitoringSocketsIfNoActiveRequest() {
+            if (!mIsMonitoringSocketsStarted) return;
+            if (isAnyRequestActive()) return;
+
             mMdnsSocketProvider.stopMonitoringSockets();
             mIsMonitoringSocketsStarted = false;
         }
 
-        private void maybeStopMonitoringSocketsIfNoActiveRequest() {
-            if (!isAnyRequestActive()) {
-                maybeStopMonitoringSockets();
-            }
-        }
-
         NsdStateMachine(String name, Handler handler) {
             super(name, handler);
             addState(mDefaultState);
@@ -358,17 +352,12 @@
                         final NsdServiceConnector connector = (NsdServiceConnector) msg.obj;
                         cInfo = mClients.remove(connector);
                         if (cInfo != null) {
-                            if (mMdnsDiscoveryManager != null) {
-                                cInfo.unregisterAllListeners();
-                            }
                             cInfo.expungeAllRequests();
-                            if (cInfo.isLegacy()) {
+                            if (cInfo.isPreSClient()) {
                                 mLegacyClientCount -= 1;
                             }
                         }
-                        if (mMdnsDiscoveryManager != null || mAdvertiser != null) {
-                            maybeStopMonitoringSocketsIfNoActiveRequest();
-                        }
+                        maybeStopMonitoringSocketsIfNoActiveRequest();
                         maybeScheduleStop();
                         break;
                     case NsdManager.DISCOVER_SERVICES:
@@ -429,7 +418,7 @@
                         cInfo = getClientInfoForReply(msg);
                         if (cInfo != null) {
                             cancelStop();
-                            cInfo.setLegacy();
+                            cInfo.setPreSClient();
                             mLegacyClientCount += 1;
                             maybeStartDaemon();
                         }
@@ -461,41 +450,45 @@
             }
 
             private boolean requestLimitReached(ClientInfo clientInfo) {
-                if (clientInfo.mClientIds.size() >= ClientInfo.MAX_LIMIT) {
+                if (clientInfo.mClientRequests.size() >= ClientInfo.MAX_LIMIT) {
                     if (DBG) Log.d(TAG, "Exceeded max outstanding requests " + clientInfo);
                     return true;
                 }
                 return false;
             }
 
-            private void storeRequestMap(int clientId, int globalId, ClientInfo clientInfo, int what) {
-                clientInfo.mClientIds.put(clientId, globalId);
-                clientInfo.mClientRequests.put(clientId, what);
+            private void storeLegacyRequestMap(int clientId, int globalId, ClientInfo clientInfo,
+                    int what) {
+                clientInfo.mClientRequests.put(clientId, new LegacyClientRequest(globalId, what));
                 mIdToClientInfoMap.put(globalId, clientInfo);
                 // Remove the cleanup event because here comes a new request.
                 cancelStop();
             }
 
-            private void removeRequestMap(int clientId, int globalId, ClientInfo clientInfo) {
-                clientInfo.mClientIds.delete(clientId);
-                clientInfo.mClientRequests.delete(clientId);
-                mIdToClientInfoMap.remove(globalId);
-                maybeScheduleStop();
-                maybeStopMonitoringSocketsIfNoActiveRequest();
-            }
-
-            private void storeListenerMap(int clientId, int transactionId, MdnsListener listener,
+            private void storeAdvertiserRequestMap(int clientId, int globalId,
                     ClientInfo clientInfo) {
-                clientInfo.mClientIds.put(clientId, transactionId);
-                clientInfo.mListeners.put(clientId, listener);
-                mIdToClientInfoMap.put(transactionId, clientInfo);
+                clientInfo.mClientRequests.put(clientId, new AdvertiserClientRequest(globalId));
+                mIdToClientInfoMap.put(globalId, clientInfo);
             }
 
-            private void removeListenerMap(int clientId, int transactionId, ClientInfo clientInfo) {
-                clientInfo.mClientIds.delete(clientId);
-                clientInfo.mListeners.delete(clientId);
-                mIdToClientInfoMap.remove(transactionId);
-                maybeStopMonitoringSocketsIfNoActiveRequest();
+            private void removeRequestMap(int clientId, int globalId, ClientInfo clientInfo) {
+                final ClientRequest existing = clientInfo.mClientRequests.get(clientId);
+                if (existing == null) return;
+                clientInfo.mClientRequests.remove(clientId);
+                mIdToClientInfoMap.remove(globalId);
+
+                if (existing instanceof LegacyClientRequest) {
+                    maybeScheduleStop();
+                } else {
+                    maybeStopMonitoringSocketsIfNoActiveRequest();
+                }
+            }
+
+            private void storeDiscoveryManagerRequestMap(int clientId, int globalId,
+                    MdnsListener listener, ClientInfo clientInfo) {
+                clientInfo.mClientRequests.put(clientId,
+                        new DiscoveryManagerRequest(globalId, listener));
+                mIdToClientInfoMap.put(globalId, clientInfo);
             }
 
             private void clearRegisteredServiceInfo(ClientInfo clientInfo) {
@@ -579,7 +572,7 @@
 
                         final NsdServiceInfo info = args.serviceInfo;
                         id = getUniqueId();
-                        if (mMdnsDiscoveryManager != null) {
+                        if (mDeps.isMdnsDiscoveryManagerEnabled(mContext)) {
                             final String serviceType = constructServiceType(info.getServiceType());
                             if (serviceType == null) {
                                 clientInfo.onDiscoverServicesFailed(clientId,
@@ -597,7 +590,7 @@
                                     .build();
                             mMdnsDiscoveryManager.registerListener(
                                     listenServiceType, listener, options);
-                            storeListenerMap(clientId, id, listener, clientInfo);
+                            storeDiscoveryManagerRequestMap(clientId, id, listener, clientInfo);
                             clientInfo.onDiscoverServicesStarted(clientId, info);
                         } else {
                             maybeStartDaemon();
@@ -606,7 +599,7 @@
                                     Log.d(TAG, "Discover " + msg.arg2 + " " + id
                                             + info.getServiceType());
                                 }
-                                storeRequestMap(clientId, id, clientInfo, msg.what);
+                                storeLegacyRequestMap(clientId, id, clientInfo, msg.what);
                                 clientInfo.onDiscoverServicesStarted(clientId, info);
                             } else {
                                 stopServiceDiscovery(id);
@@ -616,7 +609,7 @@
                         }
                         break;
                     }
-                    case NsdManager.STOP_DISCOVERY:
+                    case NsdManager.STOP_DISCOVERY: {
                         if (DBG) Log.d(TAG, "Stop service discovery");
                         args = (ListenerArgs) msg.obj;
                         clientInfo = mClients.get(args.connector);
@@ -628,23 +621,21 @@
                             break;
                         }
 
-                        try {
-                            id = clientInfo.mClientIds.get(clientId);
-                        } catch (NullPointerException e) {
-                            clientInfo.onStopDiscoveryFailed(
-                                    clientId, NsdManager.FAILURE_INTERNAL_ERROR);
+                        final ClientRequest request = clientInfo.mClientRequests.get(clientId);
+                        if (request == null) {
+                            Log.e(TAG, "Unknown client request in STOP_DISCOVERY");
                             break;
                         }
-                        if (mMdnsDiscoveryManager != null) {
-                            final MdnsListener listener = clientInfo.mListeners.get(clientId);
-                            if (listener == null) {
-                                clientInfo.onStopDiscoveryFailed(
-                                        clientId, NsdManager.FAILURE_INTERNAL_ERROR);
-                                break;
-                            }
+                        id = request.mGlobalId;
+                        // Note isMdnsDiscoveryManagerEnabled may have changed to false at this
+                        // point, so this needs to check the type of the original request to
+                        // unregister instead of looking at the flag value.
+                        if (request instanceof DiscoveryManagerRequest) {
+                            final MdnsListener listener =
+                                    ((DiscoveryManagerRequest) request).mListener;
                             mMdnsDiscoveryManager.unregisterListener(
                                     listener.getListenedServiceType(), listener);
-                            removeListenerMap(clientId, id, clientInfo);
+                            removeRequestMap(clientId, id, clientInfo);
                             clientInfo.onStopDiscoverySucceeded(clientId);
                         } else {
                             removeRequestMap(clientId, id, clientInfo);
@@ -656,7 +647,8 @@
                             }
                         }
                         break;
-                    case NsdManager.REGISTER_SERVICE:
+                    }
+                    case NsdManager.REGISTER_SERVICE: {
                         if (DBG) Log.d(TAG, "Register service");
                         args = (ListenerArgs) msg.obj;
                         clientInfo = mClients.get(args.connector);
@@ -675,7 +667,7 @@
                         }
 
                         id = getUniqueId();
-                        if (mAdvertiser != null) {
+                        if (mDeps.isMdnsAdvertiserEnabled(mContext)) {
                             final NsdServiceInfo serviceInfo = args.serviceInfo;
                             final String serviceType = serviceInfo.getServiceType();
                             final String registerServiceType = constructServiceType(serviceType);
@@ -691,12 +683,12 @@
 
                             maybeStartMonitoringSockets();
                             mAdvertiser.addService(id, serviceInfo);
-                            storeRequestMap(clientId, id, clientInfo, msg.what);
+                            storeAdvertiserRequestMap(clientId, id, clientInfo);
                         } else {
                             maybeStartDaemon();
                             if (registerService(id, args.serviceInfo)) {
                                 if (DBG) Log.d(TAG, "Register " + clientId + " " + id);
-                                storeRequestMap(clientId, id, clientInfo, msg.what);
+                                storeLegacyRequestMap(clientId, id, clientInfo, msg.what);
                                 // Return success after mDns reports success
                             } else {
                                 unregisterService(id);
@@ -706,7 +698,8 @@
 
                         }
                         break;
-                    case NsdManager.UNREGISTER_SERVICE:
+                    }
+                    case NsdManager.UNREGISTER_SERVICE: {
                         if (DBG) Log.d(TAG, "unregister service");
                         args = (ListenerArgs) msg.obj;
                         clientInfo = mClients.get(args.connector);
@@ -717,10 +710,18 @@
                             Log.e(TAG, "Unknown connector in unregistration");
                             break;
                         }
-                        id = clientInfo.mClientIds.get(clientId);
+                        final ClientRequest request = clientInfo.mClientRequests.get(clientId);
+                        if (request == null) {
+                            Log.e(TAG, "Unknown client request in UNREGISTER_SERVICE");
+                            break;
+                        }
+                        id = request.mGlobalId;
                         removeRequestMap(clientId, id, clientInfo);
 
-                        if (mAdvertiser != null) {
+                        // Note isMdnsAdvertiserEnabled may have changed to false at this point,
+                        // so this needs to check the type of the original request to unregister
+                        // instead of looking at the flag value.
+                        if (request instanceof AdvertiserClientRequest) {
                             mAdvertiser.removeService(id);
                             clientInfo.onUnregisterServiceSucceeded(clientId);
                         } else {
@@ -732,6 +733,7 @@
                             }
                         }
                         break;
+                    }
                     case NsdManager.RESOLVE_SERVICE: {
                         if (DBG) Log.d(TAG, "Resolve service");
                         args = (ListenerArgs) msg.obj;
@@ -746,7 +748,7 @@
 
                         final NsdServiceInfo info = args.serviceInfo;
                         id = getUniqueId();
-                        if (mMdnsDiscoveryManager != null) {
+                        if (mDeps.isMdnsDiscoveryManagerEnabled(mContext)) {
                             final String serviceType = constructServiceType(info.getServiceType());
                             if (serviceType == null) {
                                 clientInfo.onResolveServiceFailed(clientId,
@@ -764,7 +766,7 @@
                                     .build();
                             mMdnsDiscoveryManager.registerListener(
                                     resolveServiceType, listener, options);
-                            storeListenerMap(clientId, id, listener, clientInfo);
+                            storeDiscoveryManagerRequestMap(clientId, id, listener, clientInfo);
                         } else {
                             if (clientInfo.mResolvedService != null) {
                                 clientInfo.onResolveServiceFailed(
@@ -775,7 +777,7 @@
                             maybeStartDaemon();
                             if (resolveService(id, args.serviceInfo)) {
                                 clientInfo.mResolvedService = new NsdServiceInfo();
-                                storeRequestMap(clientId, id, clientInfo, msg.what);
+                                storeLegacyRequestMap(clientId, id, clientInfo, msg.what);
                             } else {
                                 clientInfo.onResolveServiceFailed(
                                         clientId, NsdManager.FAILURE_INTERNAL_ERROR);
@@ -783,7 +785,7 @@
                         }
                         break;
                     }
-                    case NsdManager.STOP_RESOLUTION:
+                    case NsdManager.STOP_RESOLUTION: {
                         if (DBG) Log.d(TAG, "Stop service resolution");
                         args = (ListenerArgs) msg.obj;
                         clientInfo = mClients.get(args.connector);
@@ -795,7 +797,21 @@
                             break;
                         }
 
-                        id = clientInfo.mClientIds.get(clientId);
+                        final ClientRequest request = clientInfo.mClientRequests.get(clientId);
+                        if (request == null) {
+                            Log.e(TAG, "Unknown client request in STOP_RESOLUTION");
+                            break;
+                        }
+                        id = request.mGlobalId;
+                        if (request instanceof DiscoveryManagerRequest) {
+                            // removeRequestMap() below forgets this request, which is the only
+                            // handle used to unregister its listener later; unregister it now so
+                            // the listener is not leaked in MdnsDiscoveryManager.
+                            final MdnsListener listener =
+                                    ((DiscoveryManagerRequest) request).mListener;
+                            mMdnsDiscoveryManager.unregisterListener(
+                                    listener.getListenedServiceType(), listener);
+                        }
                         removeRequestMap(clientId, id, clientInfo);
                         if (stopResolveService(id)) {
                             clientInfo.onStopResolutionSucceeded(clientId);
@@ -806,6 +813,7 @@
                         clientInfo.mResolvedService = null;
                         // TODO: Implement the stop resolution with MdnsDiscoveryManager.
                         break;
+                    }
                     case NsdManager.REGISTER_SERVICE_CALLBACK:
                         if (DBG) Log.d(TAG, "Register a service callback");
                         args = (ListenerArgs) msg.obj;
@@ -829,13 +837,13 @@
                         if (resolveService(id, args.serviceInfo)) {
                             clientInfo.mRegisteredService = new NsdServiceInfo();
                             clientInfo.mClientIdForServiceUpdates = clientId;
-                            storeRequestMap(clientId, id, clientInfo, msg.what);
+                            storeLegacyRequestMap(clientId, id, clientInfo, msg.what);
                         } else {
                             clientInfo.onServiceInfoCallbackRegistrationFailed(
                                     clientId, NsdManager.FAILURE_BAD_PARAMETERS);
                         }
                         break;
-                    case NsdManager.UNREGISTER_SERVICE_CALLBACK:
+                    case NsdManager.UNREGISTER_SERVICE_CALLBACK: {
                         if (DBG) Log.d(TAG, "Unregister a service callback");
                         args = (ListenerArgs) msg.obj;
                         clientInfo = mClients.get(args.connector);
@@ -847,7 +855,12 @@
                             break;
                         }
 
-                        id = clientInfo.mClientIds.get(clientId);
+                        final ClientRequest request = clientInfo.mClientRequests.get(clientId);
+                        if (request == null) {
+                            Log.e(TAG, "Unknown client request in UNREGISTER_SERVICE_CALLBACK");
+                            break;
+                        }
+                        id = request.mGlobalId;
                         removeRequestMap(clientId, id, clientInfo);
                         if (stopResolveService(id)) {
                             clientInfo.onServiceInfoCallbackUnregistered(clientId);
@@ -856,6 +869,7 @@
                         }
                         clearRegisteredServiceInfo(clientInfo);
                         break;
+                    }
                     case MDNS_SERVICE_EVENT:
                         if (!handleMDnsServiceEvent(msg.arg1, msg.arg2, msg.obj)) {
                             return NOT_HANDLED;
@@ -995,7 +1009,8 @@
 
                         final int id2 = getUniqueId();
                         if (getAddrInfo(id2, info.hostname, info.interfaceIdx)) {
-                            storeRequestMap(clientId, id2, clientInfo, NsdManager.RESOLVE_SERVICE);
+                            storeLegacyRequestMap(clientId, id2, clientInfo,
+                                    NsdManager.RESOLVE_SERVICE);
                         } else {
                             notifyResolveFailedResult(isListenedToUpdates, clientId, clientInfo,
                                     NsdManager.FAILURE_BAD_PARAMETERS);
@@ -1110,6 +1125,11 @@
                         clientInfo.onServiceLost(clientId, info);
                         break;
                     case NsdManager.RESOLVE_SERVICE_SUCCEEDED: {
+                        final ClientRequest request = clientInfo.mClientRequests.get(clientId);
+                        if (request == null) {
+                            Log.e(TAG, "Unknown client request in RESOLVE_SERVICE_SUCCEEDED");
+                            break;
+                        }
                         final MdnsServiceInfo serviceInfo = event.mMdnsServiceInfo;
                         // Add '.' in front of the service type that aligns with historical behavior
                         info.setServiceType("." + event.mRequestedServiceType);
@@ -1140,10 +1160,14 @@
                         }
 
                         // Unregister the listener immediately like IMDnsEventListener design
-                        final MdnsListener listener = clientInfo.mListeners.get(clientId);
+                        if (!(request instanceof DiscoveryManagerRequest)) {
+                            Log.wtf(TAG, "non-DiscoveryManager request in DiscoveryManager event");
+                            break;
+                        }
+                        final MdnsListener listener = ((DiscoveryManagerRequest) request).mListener;
                         mMdnsDiscoveryManager.unregisterListener(
                                 listener.getListenedServiceType(), listener);
-                        removeListenerMap(clientId, transactionId, clientInfo);
+                        removeRequestMap(clientId, transactionId, clientInfo);
                         break;
                     }
                     default:
@@ -1216,32 +1240,16 @@
         mNsdStateMachine.start();
         mMDnsManager = ctx.getSystemService(MDnsManager.class);
         mMDnsEventCallback = new MDnsEventCallback(mNsdStateMachine);
+        mDeps = deps;
 
-        final boolean discoveryManagerEnabled = deps.isMdnsDiscoveryManagerEnabled(ctx);
-        final boolean advertiserEnabled = deps.isMdnsAdvertiserEnabled(ctx);
-        if (discoveryManagerEnabled || advertiserEnabled) {
-            mMdnsSocketProvider = deps.makeMdnsSocketProvider(ctx, handler.getLooper());
-        } else {
-            mMdnsSocketProvider = null;
-        }
-
-        if (discoveryManagerEnabled) {
-            mMdnsSocketClient =
-                    new MdnsMultinetworkSocketClient(handler.getLooper(), mMdnsSocketProvider);
-            mMdnsDiscoveryManager =
-                    deps.makeMdnsDiscoveryManager(new ExecutorProvider(), mMdnsSocketClient);
-            handler.post(() -> mMdnsSocketClient.setCallback(mMdnsDiscoveryManager));
-        } else {
-            mMdnsSocketClient = null;
-            mMdnsDiscoveryManager = null;
-        }
-
-        if (advertiserEnabled) {
-            mAdvertiser = deps.makeMdnsAdvertiser(handler.getLooper(), mMdnsSocketProvider,
-                    new AdvertiserCallback());
-        } else {
-            mAdvertiser = null;
-        }
+        mMdnsSocketProvider = deps.makeMdnsSocketProvider(ctx, handler.getLooper());
+        mMdnsSocketClient =
+                new MdnsMultinetworkSocketClient(handler.getLooper(), mMdnsSocketProvider);
+        mMdnsDiscoveryManager =
+                deps.makeMdnsDiscoveryManager(new ExecutorProvider(), mMdnsSocketClient);
+        handler.post(() -> mMdnsSocketClient.setCallback(mMdnsDiscoveryManager));
+        mAdvertiser = deps.makeMdnsAdvertiser(handler.getLooper(), mMdnsSocketProvider,
+                new AdvertiserCallback());
     }
 
     /**
@@ -1604,6 +1612,52 @@
         mNsdStateMachine.dump(fd, pw, args);
     }
 
+    /**
+     * A request made by a client, stored in {@code ClientInfo#mClientRequests} under the
+     * client-side request ID.
+     */
+    private abstract static class ClientRequest {
+        /**
+         * The service-wide ID obtained from {@code getUniqueId()}, also used as the key into
+         * {@code mIdToClientInfoMap}.
+         */
+        private final int mGlobalId;
+
+        private ClientRequest(int globalId) {
+            mGlobalId = globalId;
+        }
+    }
+
+    /** A request served through the legacy mdns daemon path. */
+    private static final class LegacyClientRequest extends ClientRequest {
+        /** The NsdManager request code (e.g. {@code NsdManager.DISCOVER_SERVICES}). */
+        private final int mRequestCode;
+
+        private LegacyClientRequest(int globalId, int requestCode) {
+            super(globalId);
+            mRequestCode = requestCode;
+        }
+    }
+
+    /** A service registration served by {@code MdnsAdvertiser}. */
+    private static final class AdvertiserClientRequest extends ClientRequest {
+        private AdvertiserClientRequest(int globalId) {
+            super(globalId);
+        }
+    }
+
+    /** A discovery or resolve request served by {@code MdnsDiscoveryManager}. */
+    private static final class DiscoveryManagerRequest extends ClientRequest {
+        /** The listener to unregister from {@code mMdnsDiscoveryManager} when the request ends. */
+        @NonNull
+        private final MdnsListener mListener;
+
+        private DiscoveryManagerRequest(int globalId, @NonNull MdnsListener listener) {
+            super(globalId);
+            mListener = listener;
+        }
+    }
+
     /* Information tracked per client */
     private class ClientInfo {
 
@@ -1612,17 +1653,11 @@
         /* Remembers a resolved service until getaddrinfo completes */
         private NsdServiceInfo mResolvedService;
 
-        /* A map from client id to unique id sent to mDns */
-        private final SparseIntArray mClientIds = new SparseIntArray();
-
-        /* A map from client id to the type of the request we had received */
-        private final SparseIntArray mClientRequests = new SparseIntArray();
-
-        /* A map from client id to the MdnsListener */
-        private final SparseArray<MdnsListener> mListeners = new SparseArray<>();
+        /* A map from client-side ID (listenerKey) to the request */
+        private final SparseArray<ClientRequest> mClientRequests = new SparseArray<>();
 
         // The target SDK of this client < Build.VERSION_CODES.S
-        private boolean mIsLegacy = false;
+        private boolean mIsPreSClient = false;
 
         /*** The service that is registered to listen to its updates */
         private NsdServiceInfo mRegisteredService;
@@ -1638,38 +1673,59 @@
         public String toString() {
             StringBuilder sb = new StringBuilder();
             sb.append("mResolvedService ").append(mResolvedService).append("\n");
-            sb.append("mIsLegacy ").append(mIsLegacy).append("\n");
-            for(int i = 0; i< mClientIds.size(); i++) {
-                int clientID = mClientIds.keyAt(i);
-                sb.append("clientId ").append(clientID).
-                    append(" mDnsId ").append(mClientIds.valueAt(i)).
-                    append(" type ").append(mClientRequests.get(clientID)).append("\n");
+            sb.append("mIsLegacy ").append(mIsPreSClient).append("\n");
+            for (int i = 0; i < mClientRequests.size(); i++) {
+                int clientID = mClientRequests.keyAt(i);
+                sb.append("clientId ")
+                        .append(clientID)
+                        .append(" mDnsId ").append(mClientRequests.valueAt(i).mGlobalId)
+                        .append(" type ").append(
+                                mClientRequests.valueAt(i).getClass().getSimpleName())
+                        .append("\n");
             }
             return sb.toString();
         }
 
-        private boolean isLegacy() {
-            return mIsLegacy;
+        private boolean isPreSClient() {
+            return mIsPreSClient;
         }
 
-        private void setLegacy() {
-            mIsLegacy = true;
+        private void setPreSClient() {
+            mIsPreSClient = true;
         }
 
         // Remove any pending requests from the global map when we get rid of a client,
-        // and send cancellations to the daemon.
+        // and send cancellations to the daemon, MdnsDiscoveryManager or MdnsAdvertiser.
         private void expungeAllRequests() {
-            int globalId, clientId, i;
             // TODO: to keep handler responsive, do not clean all requests for that client at once.
-            for (i = 0; i < mClientIds.size(); i++) {
-                clientId = mClientIds.keyAt(i);
-                globalId = mClientIds.valueAt(i);
+            for (int i = 0; i < mClientRequests.size(); i++) {
+                final int clientId = mClientRequests.keyAt(i);
+                final ClientRequest request = mClientRequests.valueAt(i);
+                final int globalId = request.mGlobalId;
                 mIdToClientInfoMap.remove(globalId);
                 if (DBG) {
-                    Log.d(TAG, "Terminating client-ID " + clientId
-                            + " global-ID " + globalId + " type " + mClientRequests.get(clientId));
+                    Log.d(TAG, "Terminating client-ID " + clientId + " global-ID " + globalId
+                            + " type " + request.getClass().getSimpleName());
                 }
-                switch (mClientRequests.get(clientId)) {
+
+                if (request instanceof DiscoveryManagerRequest) {
+                    final MdnsListener listener =
+                            ((DiscoveryManagerRequest) request).mListener;
+                    mMdnsDiscoveryManager.unregisterListener(
+                            listener.getListenedServiceType(), listener);
+                    continue;
+                }
+
+                if (request instanceof AdvertiserClientRequest) {
+                    mAdvertiser.removeService(globalId);
+                    continue;
+                }
+
+                if (!(request instanceof LegacyClientRequest)) {
+                    throw new IllegalStateException("Unknown request type: " + request.getClass());
+                }
+
+                switch (((LegacyClientRequest) request).mRequestCode) {
                     case NsdManager.DISCOVER_SERVICES:
                         stopServiceDiscovery(globalId);
                         break;
@@ -1677,37 +1733,25 @@
                         stopResolveService(globalId);
                         break;
                     case NsdManager.REGISTER_SERVICE:
-                        if (mAdvertiser != null) {
-                            mAdvertiser.removeService(globalId);
-                        } else {
-                            unregisterService(globalId);
-                        }
+                        unregisterService(globalId);
                         break;
                     default:
                         break;
                 }
             }
-            mClientIds.clear();
             mClientRequests.clear();
         }
 
-        void unregisterAllListeners() {
-            for (int i = 0; i < mListeners.size(); i++) {
-                final MdnsListener listener = mListeners.valueAt(i);
-                mMdnsDiscoveryManager.unregisterListener(
-                        listener.getListenedServiceType(), listener);
-            }
-            mListeners.clear();
-        }
-
-        // mClientIds is a sparse array of listener id -> mDnsClient id.  For a given mDnsClient id,
-        // return the corresponding listener id.  mDnsClient id is also called a global id.
+        // mClientRequests is a sparse array of listener id -> ClientRequest.  For a given
+        // mDnsClient id, return the corresponding listener id.  mDnsClient id is also called a
+        // global id.
         private int getClientId(final int globalId) {
-            int idx = mClientIds.indexOfValue(globalId);
-            if (idx < 0) {
-                return idx;
+            for (int i = 0; i < mClientRequests.size(); i++) {
+                if (mClientRequests.valueAt(i).mGlobalId == globalId) {
+                    return mClientRequests.keyAt(i);
+                }
             }
-            return mClientIds.keyAt(idx);
+            return -1;
         }
 
         private void maybeNotifyRegisteredServiceLost(@NonNull NsdServiceInfo info) {
diff --git a/service/src/com/android/server/ConnectivityService.java b/service/src/com/android/server/ConnectivityService.java
index cd4421b..a570ab1 100755
--- a/service/src/com/android/server/ConnectivityService.java
+++ b/service/src/com/android/server/ConnectivityService.java
@@ -107,6 +107,7 @@
 import android.Manifest;
 import android.annotation.NonNull;
 import android.annotation.Nullable;
+import android.annotation.SuppressLint;
 import android.annotation.TargetApi;
 import android.app.AppOpsManager;
 import android.app.BroadcastOptions;
@@ -3038,6 +3039,8 @@
         sendStickyBroadcast(makeGeneralIntent(info, bcastType));
     }
 
+    // TODO(b/193460475): Remove when tooling supports SystemApi to public API.
+    @SuppressLint("NewApi")
     // TODO: Set the mini sdk to 31 and remove @TargetApi annotation when b/205923322 is addressed.
     @TargetApi(Build.VERSION_CODES.S)
     private void sendStickyBroadcast(Intent intent) {
@@ -5547,7 +5550,9 @@
                     break;
                 }
                 case NetworkAgent.CMD_MONITOR_AUTOMATIC_KEEPALIVE: {
-                    final AutomaticOnOffKeepalive ki = (AutomaticOnOffKeepalive) msg.obj;
+                    final AutomaticOnOffKeepalive ki =
+                            mKeepaliveTracker.getKeepaliveForBinder((IBinder) msg.obj);
+                    if (null == ki) return; // The callback was unregistered before the alarm fired
 
                     final Network network = ki.getNetwork();
                     boolean networkFound = false;
@@ -5575,10 +5580,14 @@
                 }
                 // Sent by KeepaliveTracker to process an app request on the state machine thread.
                 case NetworkAgent.CMD_STOP_SOCKET_KEEPALIVE: {
-                    NetworkAgentInfo nai = getNetworkAgentInfoForNetwork((Network) msg.obj);
-                    int slot = msg.arg1;
-                    int reason = msg.arg2;
-                    mKeepaliveTracker.handleStopKeepalive(nai, slot, reason);
+                    final AutomaticOnOffKeepalive ki = mKeepaliveTracker.getKeepaliveForBinder(
+                            (IBinder) msg.obj);
+                    if (ki == null) {
+                        Log.e(TAG, "Attempt to stop an already stopped keepalive");
+                        return;
+                    }
+                    final int reason = msg.arg2;
+                    mKeepaliveTracker.handleStopKeepalive(ki, reason);
                     break;
                 }
                 case EVENT_REPORT_NETWORK_CONNECTIVITY: {
@@ -8512,6 +8521,8 @@
         // else not handled
     }
 
+    // TODO(b/193460475): Remove when tooling supports SystemApi to public API.
+    @SuppressLint("NewApi")
     private void sendIntent(PendingIntent pendingIntent, Intent intent) {
         mPendingIntentWakeLock.acquire();
         try {
@@ -9859,9 +9870,10 @@
     }
 
     @Override
-    public void stopKeepalive(Network network, int slot) {
+    public void stopKeepalive(@NonNull final ISocketKeepaliveCallback cb) {
         mHandler.sendMessage(mHandler.obtainMessage(
-                NetworkAgent.CMD_STOP_SOCKET_KEEPALIVE, slot, SocketKeepalive.SUCCESS, network));
+                NetworkAgent.CMD_STOP_SOCKET_KEEPALIVE, 0, SocketKeepalive.SUCCESS,
+                Objects.requireNonNull(cb).asBinder()));
     }
 
     @Override
diff --git a/service/src/com/android/server/connectivity/AutomaticOnOffKeepaliveTracker.java b/service/src/com/android/server/connectivity/AutomaticOnOffKeepaliveTracker.java
index b6627c6..8bfbcf7 100644
--- a/service/src/com/android/server/connectivity/AutomaticOnOffKeepaliveTracker.java
+++ b/service/src/com/android/server/connectivity/AutomaticOnOffKeepaliveTracker.java
@@ -45,6 +45,7 @@
 import android.net.Network;
 import android.net.NetworkAgent;
 import android.net.SocketKeepalive.InvalidSocketException;
+import android.os.Bundle;
 import android.os.FileUtils;
 import android.os.Handler;
 import android.os.IBinder;
@@ -53,6 +54,7 @@
 import android.os.SystemClock;
 import android.system.ErrnoException;
 import android.system.Os;
+import android.system.OsConstants;
 import android.system.StructTimeval;
 import android.util.Log;
 import android.util.SparseArray;
@@ -60,11 +62,13 @@
 import com.android.internal.annotations.VisibleForTesting;
 import com.android.internal.util.IndentingPrintWriter;
 import com.android.modules.utils.build.SdkLevel;
+import com.android.net.module.util.BinderUtils;
 import com.android.net.module.util.CollectionUtils;
 import com.android.net.module.util.DeviceConfigUtils;
 import com.android.net.module.util.HexDump;
 import com.android.net.module.util.SocketUtils;
 import com.android.net.module.util.netlink.InetDiagMessage;
+import com.android.net.module.util.netlink.NetlinkMessage;
 import com.android.net.module.util.netlink.NetlinkUtils;
 import com.android.net.module.util.netlink.StructNlAttr;
 
@@ -92,8 +96,7 @@
     private static final int[] ADDRESS_FAMILIES = new int[] {AF_INET6, AF_INET};
     private static final String ACTION_TCP_POLLING_ALARM =
             "com.android.server.connectivity.KeepaliveTracker.TCP_POLLING_ALARM";
-    private static final String EXTRA_NETWORK = "network_id";
-    private static final String EXTRA_SLOT = "slot";
+    private static final String EXTRA_BINDER_TOKEN = "token";
     private static final long DEFAULT_TCP_POLLING_INTERVAL_MS = 120_000L;
     private static final String AUTOMATIC_ON_OFF_KEEPALIVE_VERSION =
             "automatic_on_off_keepalive_version";
@@ -159,11 +162,10 @@
         public void onReceive(Context context, Intent intent) {
             if (ACTION_TCP_POLLING_ALARM.equals(intent.getAction())) {
                 Log.d(TAG, "Received TCP polling intent");
-                final Network network = intent.getParcelableExtra(EXTRA_NETWORK);
-                final int slot = intent.getIntExtra(EXTRA_SLOT, -1);
+                final IBinder token = intent.getBundleExtra(EXTRA_BINDER_TOKEN).getBinder(
+                        EXTRA_BINDER_TOKEN);
                 mConnectivityServiceHandler.obtainMessage(
-                        NetworkAgent.CMD_MONITOR_AUTOMATIC_KEEPALIVE,
-                        slot, 0 , network).sendToTarget();
+                        NetworkAgent.CMD_MONITOR_AUTOMATIC_KEEPALIVE, token).sendToTarget();
             }
         }
     };
@@ -183,6 +185,8 @@
     public class AutomaticOnOffKeepalive {
         @NonNull
         private final KeepaliveTracker.KeepaliveInfo mKi;
+        @NonNull
+        private final ISocketKeepaliveCallback mCallback;
         @Nullable
         private final FileDescriptor mFd;
         @Nullable
@@ -193,6 +197,7 @@
         AutomaticOnOffKeepalive(@NonNull final KeepaliveTracker.KeepaliveInfo ki,
                 final boolean autoOnOff, @NonNull Context context) throws InvalidSocketException {
             this.mKi = Objects.requireNonNull(ki);
+            mCallback = ki.mCallback;
             if (autoOnOff && mDependencies.isFeatureEnabled(AUTOMATIC_ON_OFF_KEEPALIVE_VERSION)) {
                 mAutomaticOnOffState = STATE_ENABLED;
                 if (null == ki.mFd) {
@@ -205,8 +210,7 @@
                     Log.e(TAG, "Cannot dup fd: ", e);
                     throw new InvalidSocketException(ERROR_INVALID_SOCKET, e);
                 }
-                mTcpPollingAlarm = createTcpPollingAlarmIntent(
-                        context, ki.getNai().network(), ki.getSlot());
+                mTcpPollingAlarm = createTcpPollingAlarmIntent(context, mCallback.asBinder());
             } else {
                 mAutomaticOnOffState = STATE_ALWAYS_ON;
                 // A null fd is acceptable in KeepaliveInfo for backward compatibility of
@@ -226,12 +230,14 @@
         }
 
         private PendingIntent createTcpPollingAlarmIntent(@NonNull Context context,
-                @NonNull Network network, int slot) {
+                @NonNull IBinder token) {
             final Intent intent = new Intent(ACTION_TCP_POLLING_ALARM);
-            intent.putExtra(EXTRA_NETWORK, network);
-            intent.putExtra(EXTRA_SLOT, slot);
-            return PendingIntent.getBroadcast(
-                    context, 0 /* requestCode */, intent, PendingIntent.FLAG_IMMUTABLE);
+            // Intent doesn't expose methods to put extra Binders, but Bundle does.
+            final Bundle b = new Bundle();
+            b.putBinder(EXTRA_BINDER_TOKEN, token);
+            intent.putExtra(EXTRA_BINDER_TOKEN, b);
+            return BinderUtils.withCleanCallingIdentity(() -> PendingIntent.getBroadcast(
+                    context, 0 /* requestCode */, intent, PendingIntent.FLAG_IMMUTABLE));
         }
     }
 
@@ -318,33 +324,23 @@
             newKi = autoKi.mKi.withFd(autoKi.mFd);
         } catch (InvalidSocketException | IllegalArgumentException | SecurityException e) {
             Log.e(TAG, "Fail to construct keepalive", e);
-            mKeepaliveTracker.notifyErrorCallback(autoKi.mKi.mCallback, ERROR_INVALID_SOCKET);
+            mKeepaliveTracker.notifyErrorCallback(autoKi.mCallback, ERROR_INVALID_SOCKET);
             return;
         }
         autoKi.mAutomaticOnOffState = STATE_ENABLED;
         handleResumeKeepalive(newKi);
     }
 
-    private int findAutomaticOnOffKeepaliveIndex(@NonNull Network network, int slot) {
-        ensureRunningOnHandlerThread();
-
-        int index = 0;
-        for (AutomaticOnOffKeepalive ki : mAutomaticOnOffKeepalives) {
-            if (ki.match(network, slot)) {
-                return index;
-            }
-            index++;
-        }
-        return -1;
-    }
-
+    /**
+     * Find the AutomaticOnOffKeepalive associated with a given callback.
+     * @return the keepalive associated with this callback, or null if none
+     */
     @Nullable
-    private AutomaticOnOffKeepalive findAutomaticOnOffKeepalive(@NonNull Network network,
-            int slot) {
+    public AutomaticOnOffKeepalive getKeepaliveForBinder(@NonNull final IBinder token) {
         ensureRunningOnHandlerThread();
 
-        final int index = findAutomaticOnOffKeepaliveIndex(network, slot);
-        return (index >= 0) ? mAutomaticOnOffKeepalives.get(index) : null;
+        return CollectionUtils.findFirst(mAutomaticOnOffKeepalives,
+                it -> it.mCallback.asBinder().equals(token));
     }
 
     /**
@@ -397,17 +393,12 @@
     /**
      * Handle stop keepalives on the specific network with given slot.
      */
-    public void handleStopKeepalive(NetworkAgentInfo nai, int slot, int reason) {
-        final AutomaticOnOffKeepalive autoKi = findAutomaticOnOffKeepalive(nai.network, slot);
-        if (null == autoKi) {
-            Log.e(TAG, "Attempt to stop nonexistent keepalive " + slot + " on " + nai);
-            return;
-        }
-
+    public void handleStopKeepalive(@NonNull final AutomaticOnOffKeepalive autoKi, int reason) {
         // Stop the keepalive unless it was suspended. This includes the case where it's managed
         // but enabled, and the case where it's always on.
         if (autoKi.mAutomaticOnOffState != STATE_SUSPENDED) {
-            mKeepaliveTracker.handleStopKeepalive(nai, slot, reason);
+            final KeepaliveTracker.KeepaliveInfo ki = autoKi.mKi;
+            mKeepaliveTracker.handleStopKeepalive(ki.getNai(), ki.getSlot(), reason);
         }
 
         cleanupAutoOnOffKeepalive(autoKi);
@@ -581,6 +572,16 @@
                     bytes.position(startPos + SOCKDIAG_MSG_HEADER_SIZE);
 
                     if (isTargetTcpSocket(bytes, nlmsgLen, networkMark, networkMask)) {
+                        if (Log.isLoggable(TAG, Log.DEBUG)) {
+                            bytes.position(startPos);
+                            final InetDiagMessage diagMsg = (InetDiagMessage) NetlinkMessage.parse(
+                                    bytes, OsConstants.NETLINK_INET_DIAG);
+                            Log.d(TAG, String.format("Found open TCP connection by uid %d to %s"
+                                            + " cookie %d",
+                                    diagMsg.inetDiagMsg.idiag_uid,
+                                    diagMsg.inetDiagMsg.id.remSocketAddress,
+                                    diagMsg.inetDiagMsg.id.cookie));
+                        }
                         return true;
                     }
                 }
diff --git a/service/src/com/android/server/connectivity/KeepaliveTracker.java b/service/src/com/android/server/connectivity/KeepaliveTracker.java
index 63b76c7..a512b7c 100644
--- a/service/src/com/android/server/connectivity/KeepaliveTracker.java
+++ b/service/src/com/android/server/connectivity/KeepaliveTracker.java
@@ -125,8 +125,9 @@
      * which is only returned when the hardware has successfully started the keepalive.
      */
     class KeepaliveInfo implements IBinder.DeathRecipient {
-        // Bookkeeping data.
+        // TODO : remove this member. Only AutoOnOffKeepalive should have a reference to this.
         public final ISocketKeepaliveCallback mCallback;
+        // Bookkeeping data.
         private final int mUid;
         private final int mPid;
         private final boolean mPrivileged;
@@ -588,9 +589,9 @@
                 Log.d(TAG, "Started keepalive " + slot + " on " + nai.toShortString());
                 ki.mStartedState = KeepaliveInfo.STARTED;
                 try {
-                    ki.mCallback.onStarted(slot);
+                    ki.mCallback.onStarted();
                 } catch (RemoteException e) {
-                    Log.w(TAG, "Discarded onStarted(" + slot + ") callback");
+                    Log.w(TAG, "Discarded onStarted callback");
                 }
             } else {
                 Log.d(TAG, "Failed to start keepalive " + slot + " on " + nai.toShortString()
diff --git a/tests/common/AndroidTest_Coverage.xml b/tests/common/AndroidTest_Coverage.xml
index 48d26b8..c94ec27 100644
--- a/tests/common/AndroidTest_Coverage.xml
+++ b/tests/common/AndroidTest_Coverage.xml
@@ -13,7 +13,7 @@
      limitations under the License.
 -->
 <configuration description="Runs coverage tests for Connectivity">
-    <target_preparer class="com.android.tradefed.targetprep.TestAppInstallSetup">
+    <target_preparer class="com.android.tradefed.targetprep.suite.SuiteApkInstaller">
       <option name="test-file-name" value="ConnectivityCoverageTests.apk" />
       <option name="install-arg" value="-t" />
     </target_preparer>
diff --git a/tests/cts/net/native/dns/Android.bp b/tests/cts/net/native/dns/Android.bp
index 434e529..49b9337 100644
--- a/tests/cts/net/native/dns/Android.bp
+++ b/tests/cts/net/native/dns/Android.bp
@@ -24,6 +24,10 @@
         "liblog",
         "libutils",
     ],
+    static_libs: [
+        "libbase",
+        "libnetdutils",
+    ],
     // To be compatible with Q devices, the min_sdk_version must be 29.
     min_sdk_version: "29",
 }
diff --git a/tests/cts/net/native/dns/NativeDnsAsyncTest.cpp b/tests/cts/net/native/dns/NativeDnsAsyncTest.cpp
index e501475..68bd227 100644
--- a/tests/cts/net/native/dns/NativeDnsAsyncTest.cpp
+++ b/tests/cts/net/native/dns/NativeDnsAsyncTest.cpp
@@ -28,6 +28,7 @@
 
 #include <android/multinetwork.h>
 #include <gtest/gtest.h>
+#include <netdutils/NetNativeTestBase.h>
 
 namespace {
 constexpr int MAXPACKET = 8 * 1024;
@@ -101,7 +102,9 @@
 
 } // namespace
 
-TEST (NativeDnsAsyncTest, Async_Query) {
+class NativeDnsAsyncTest : public NetNativeTestBase {};
+
+TEST_F(NativeDnsAsyncTest, Async_Query) {
     // V4
     int fd1 = android_res_nquery(
             NETWORK_UNSPECIFIED, "www.google.com", ns_c_in, ns_t_a, 0);
@@ -123,7 +126,7 @@
     expectAnswersValid(fd1, AF_INET6, ns_r_noerror);
 }
 
-TEST (NativeDnsAsyncTest, Async_Send) {
+TEST_F(NativeDnsAsyncTest, Async_Send) {
     // V4
     uint8_t buf1[MAXPACKET] = {};
     int len1 = res_mkquery(ns_o_query, "www.googleapis.com",
@@ -162,7 +165,7 @@
     expectAnswersValid(fd1, AF_INET6, ns_r_noerror);
 }
 
-TEST (NativeDnsAsyncTest, Async_NXDOMAIN) {
+TEST_F(NativeDnsAsyncTest, Async_NXDOMAIN) {
     uint8_t buf[MAXPACKET] = {};
     int len = res_mkquery(ns_o_query, "test1-nx.metric.gstatic.com",
             ns_c_in, ns_t_a, nullptr, 0, nullptr, buf, sizeof(buf));
@@ -191,7 +194,7 @@
     expectAnswersValid(fd1, AF_INET6, ns_r_nxdomain);
 }
 
-TEST (NativeDnsAsyncTest, Async_Cancel) {
+TEST_F(NativeDnsAsyncTest, Async_Cancel) {
     int fd = android_res_nquery(
             NETWORK_UNSPECIFIED, "www.google.com", ns_c_in, ns_t_a, 0);
     errno = 0;
@@ -202,7 +205,7 @@
     // otherwise it will hit fdsan double-close fd.
 }
 
-TEST (NativeDnsAsyncTest, Async_Query_MALFORMED) {
+TEST_F(NativeDnsAsyncTest, Async_Query_MALFORMED) {
     // Empty string to create BLOB and query, we will get empty result and rcode = 0
     // on DNSTLS.
     int fd = android_res_nquery(
@@ -221,7 +224,7 @@
     EXPECT_EQ(-EMSGSIZE, fd);
 }
 
-TEST (NativeDnsAsyncTest, Async_Send_MALFORMED) {
+TEST_F(NativeDnsAsyncTest, Async_Send_MALFORMED) {
     uint8_t buf[10] = {};
     // empty BLOB
     int fd = android_res_nsend(NETWORK_UNSPECIFIED, buf, 10, 0);
diff --git a/tests/cts/net/src/android/net/cts/ConnectivityManagerTest.java b/tests/cts/net/src/android/net/cts/ConnectivityManagerTest.java
index d4b23a3..ccba983 100644
--- a/tests/cts/net/src/android/net/cts/ConnectivityManagerTest.java
+++ b/tests/cts/net/src/android/net/cts/ConnectivityManagerTest.java
@@ -1140,11 +1140,8 @@
                 .setPackage(mContext.getPackageName());
         // While ConnectivityService would put extra info such as network or request id before
         // broadcasting the inner intent. The MUTABLE flag needs to be added accordingly.
-        // TODO: replace with PendingIntent.FLAG_MUTABLE when this code compiles against S+ or
-        //  shims.
-        final int pendingIntentFlagMutable = 1 << 25;
         final PendingIntent pendingIntent = PendingIntent.getBroadcast(mContext, 0 /*requestCode*/,
-                intent, PendingIntent.FLAG_CANCEL_CURRENT | pendingIntentFlagMutable);
+                intent, PendingIntent.FLAG_CANCEL_CURRENT | PendingIntent.FLAG_MUTABLE);
 
         // We will register for a WIFI network being available or lost.
         mCm.registerNetworkCallback(makeWifiNetworkRequest(), pendingIntent);
@@ -1184,15 +1181,13 @@
         // Avoid receiving broadcasts from other runs by appending a timestamp
         final String broadcastAction = NETWORK_CALLBACK_ACTION + System.currentTimeMillis();
         try {
-            // TODO: replace with PendingIntent.FLAG_MUTABLE when this code compiles against S+
             // Intent is mutable to receive EXTRA_NETWORK_REQUEST from ConnectivityService
-            final int pendingIntentFlagMutable = 1 << 25;
             final String extraBoolKey = "extra_bool";
             firstIntent = PendingIntent.getBroadcast(mContext,
                     0 /* requestCode */,
                     new Intent(broadcastAction).putExtra(extraBoolKey, false)
                             .setPackage(mContext.getPackageName()),
-                    PendingIntent.FLAG_UPDATE_CURRENT | pendingIntentFlagMutable);
+                    PendingIntent.FLAG_UPDATE_CURRENT | PendingIntent.FLAG_MUTABLE);
 
             if (useListen) {
                 mCm.registerNetworkCallback(firstRequest, firstIntent);
@@ -1206,7 +1201,7 @@
                     0 /* requestCode */,
                     new Intent(broadcastAction).putExtra(extraBoolKey, true)
                             .setPackage(mContext.getPackageName()),
-                    PendingIntent.FLAG_UPDATE_CURRENT | pendingIntentFlagMutable);
+                    PendingIntent.FLAG_UPDATE_CURRENT | PendingIntent.FLAG_MUTABLE);
 
             // Because secondIntent.intentFilterEquals the first, the request should be replaced
             if (useListen) {
diff --git a/tests/cts/net/src/android/net/cts/Ikev2VpnTest.java b/tests/cts/net/src/android/net/cts/Ikev2VpnTest.java
index ac50740..6ba0fda 100644
--- a/tests/cts/net/src/android/net/cts/Ikev2VpnTest.java
+++ b/tests/cts/net/src/android/net/cts/Ikev2VpnTest.java
@@ -529,11 +529,10 @@
             assertFalse(profileState.isLockdownEnabled());
         }
 
-        cb.expectCapabilitiesThat(vpnNetwork, TIMEOUT_MS,
-                caps -> caps.hasTransport(TRANSPORT_VPN)
-                && caps.hasCapability(NET_CAPABILITY_INTERNET)
-                && !caps.hasCapability(NET_CAPABILITY_VALIDATED)
-                && Process.myUid() == caps.getOwnerUid());
+        cb.expectCaps(vpnNetwork, TIMEOUT_MS, c -> c.hasTransport(TRANSPORT_VPN)
+                && c.hasCapability(NET_CAPABILITY_INTERNET)
+                && !c.hasCapability(NET_CAPABILITY_VALIDATED)
+                && Process.myUid() == c.getOwnerUid());
         cb.expect(CallbackEntry.LINK_PROPERTIES_CHANGED, vpnNetwork);
         cb.expect(CallbackEntry.BLOCKED_STATUS, vpnNetwork);
 
diff --git a/tests/cts/net/src/android/net/cts/IpSecManagerTest.java b/tests/cts/net/src/android/net/cts/IpSecManagerTest.java
index 8234ec1..4fa0080 100644
--- a/tests/cts/net/src/android/net/cts/IpSecManagerTest.java
+++ b/tests/cts/net/src/android/net/cts/IpSecManagerTest.java
@@ -16,6 +16,7 @@
 
 package android.net.cts;
 
+import static android.Manifest.permission.NETWORK_SETTINGS;
 import static android.net.IpSecAlgorithm.AUTH_AES_CMAC;
 import static android.net.IpSecAlgorithm.AUTH_AES_XCBC;
 import static android.net.IpSecAlgorithm.AUTH_CRYPT_AES_GCM;
@@ -52,7 +53,9 @@
 
 import static com.android.compatibility.common.util.PropertyUtil.getFirstApiLevel;
 import static com.android.compatibility.common.util.PropertyUtil.getVendorApiLevel;
+import static com.android.testutils.DeviceInfoUtils.isKernelVersionAtLeast;
 import static com.android.testutils.MiscAsserts.assertThrows;
+import static com.android.testutils.TestPermissionUtil.runAsShell;
 
 import static org.junit.Assert.assertArrayEquals;
 import static org.junit.Assert.assertEquals;
@@ -62,6 +65,8 @@
 
 import android.net.IpSecAlgorithm;
 import android.net.IpSecManager;
+import android.net.IpSecManager.SecurityParameterIndex;
+import android.net.IpSecManager.UdpEncapsulationSocket;
 import android.net.IpSecTransform;
 import android.net.TrafficStats;
 import android.os.Build;
@@ -73,6 +78,7 @@
 import androidx.test.InstrumentationRegistry;
 import androidx.test.runner.AndroidJUnit4;
 
+import com.android.modules.utils.build.SdkLevel;
 import com.android.testutils.DevSdkIgnoreRule;
 import com.android.testutils.DevSdkIgnoreRule.IgnoreUpTo;
 
@@ -120,7 +126,7 @@
     @Test
     public void testAllocSpi() throws Exception {
         for (InetAddress addr : GOOGLE_DNS_LIST) {
-            IpSecManager.SecurityParameterIndex randomSpi = null, droidSpi = null;
+            SecurityParameterIndex randomSpi, droidSpi;
             randomSpi = mISM.allocateSecurityParameterIndex(addr);
             assertTrue(
                     "Failed to receive a valid SPI",
@@ -258,6 +264,24 @@
         accepted.close();
     }
 
+    private IpSecTransform buildTransportModeTransform(
+            SecurityParameterIndex spi, InetAddress localAddr,
+            UdpEncapsulationSocket encapSocket)
+            throws Exception {
+        final IpSecTransform.Builder builder =
+                new IpSecTransform.Builder(InstrumentationRegistry.getContext())
+                        .setEncryption(new IpSecAlgorithm(IpSecAlgorithm.CRYPT_AES_CBC, CRYPT_KEY))
+                        .setAuthentication(
+                                new IpSecAlgorithm(
+                                        IpSecAlgorithm.AUTH_HMAC_SHA256,
+                                        AUTH_KEY,
+                                        AUTH_KEY.length * 8));
+        if (encapSocket != null) {
+            builder.setIpv4Encapsulation(encapSocket, encapSocket.getPort());
+        }
+        return builder.buildTransportModeTransform(localAddr, spi);
+    }
+
     /*
      * Alloc outbound SPI
      * Alloc inbound SPI
@@ -268,21 +292,8 @@
      * release transform
      * send data (expect exception)
      */
-    @Test
-    public void testCreateTransform() throws Exception {
-        InetAddress localAddr = InetAddress.getByName(IPV4_LOOPBACK);
-        IpSecManager.SecurityParameterIndex spi =
-                mISM.allocateSecurityParameterIndex(localAddr);
-
-        IpSecTransform transform =
-                new IpSecTransform.Builder(InstrumentationRegistry.getContext())
-                        .setEncryption(new IpSecAlgorithm(IpSecAlgorithm.CRYPT_AES_CBC, CRYPT_KEY))
-                        .setAuthentication(
-                                new IpSecAlgorithm(
-                                        IpSecAlgorithm.AUTH_HMAC_SHA256,
-                                        AUTH_KEY,
-                                        AUTH_KEY.length * 8))
-                        .buildTransportModeTransform(localAddr, spi);
+    private void doTestCreateTransform(String loopbackAddrString, boolean encap) throws Exception {
+        InetAddress localAddr = InetAddress.getByName(loopbackAddrString);
 
         final boolean [][] applyInApplyOut = {
                 {false, false}, {false, true}, {true, false}, {true,true}};
@@ -291,50 +302,93 @@
 
         byte[] in = new byte[data.length];
         DatagramPacket inPacket = new DatagramPacket(in, in.length);
-        DatagramSocket localSocket;
         int localPort;
 
         for(boolean[] io : applyInApplyOut) {
             boolean applyIn = io[0];
             boolean applyOut = io[1];
-            // Bind localSocket to a random available port.
-            localSocket = new DatagramSocket(0);
-            localPort = localSocket.getLocalPort();
-            localSocket.setSoTimeout(200);
-            outPacket.setPort(localPort);
-            if (applyIn) {
-                mISM.applyTransportModeTransform(
-                        localSocket, IpSecManager.DIRECTION_IN, transform);
-            }
-            if (applyOut) {
-                mISM.applyTransportModeTransform(
-                        localSocket, IpSecManager.DIRECTION_OUT, transform);
-            }
-            if (applyIn == applyOut) {
-                localSocket.send(outPacket);
-                localSocket.receive(inPacket);
-                assertTrue("Encapsulated data did not match.",
-                        Arrays.equals(outPacket.getData(), inPacket.getData()));
-                mISM.removeTransportModeTransforms(localSocket);
-                localSocket.close();
-            } else {
-                try {
+            try (
+                SecurityParameterIndex spi = mISM.allocateSecurityParameterIndex(localAddr);
+                UdpEncapsulationSocket encapSocket = encap
+                        ? getPrivilegedUdpEncapSocket(/*ipv6=*/ localAddr instanceof Inet6Address)
+                        : null;
+                IpSecTransform transform = buildTransportModeTransform(spi, localAddr,
+                        encapSocket);
+                // Bind localSocket to a random available port.
+                DatagramSocket localSocket = new DatagramSocket(0);
+            ) {
+                localPort = localSocket.getLocalPort();
+                localSocket.setSoTimeout(200);
+                outPacket.setPort(localPort);
+                if (applyIn) {
+                    mISM.applyTransportModeTransform(
+                            localSocket, IpSecManager.DIRECTION_IN, transform);
+                }
+                if (applyOut) {
+                    mISM.applyTransportModeTransform(
+                            localSocket, IpSecManager.DIRECTION_OUT, transform);
+                }
+                if (applyIn == applyOut) {
                     localSocket.send(outPacket);
                     localSocket.receive(inPacket);
-                } catch (IOException e) {
-                    continue;
-                } finally {
+                    assertTrue("Encrypted data did not match.",
+                            Arrays.equals(outPacket.getData(), inPacket.getData()));
                     mISM.removeTransportModeTransforms(localSocket);
-                    localSocket.close();
+                } else {
+                    try {
+                        localSocket.send(outPacket);
+                        localSocket.receive(inPacket);
+                    } catch (IOException e) {
+                        continue;
+                    } finally {
+                        mISM.removeTransportModeTransforms(localSocket);
+                    }
+                    // FIXME: This check is disabled because sockets currently receive data
+                    // if there is a valid SA for decryption, even when the input policy is
+                    // not applied to a socket.
+                    //  fail("Data IO should fail on asymmetrical transforms! + Input="
+                    //          + applyIn + " Output=" + applyOut);
                 }
-                // FIXME: This check is disabled because sockets currently receive data
-                // if there is a valid SA for decryption, even when the input policy is
-                // not applied to a socket.
-                //  fail("Data IO should fail on asymmetrical transforms! + Input="
-                //          + applyIn + " Output=" + applyOut);
             }
         }
-        transform.close();
+    }
+
+    private UdpEncapsulationSocket getPrivilegedUdpEncapSocket(boolean ipv6) throws Exception {
+        return runAsShell(NETWORK_SETTINGS, () -> {
+            if (ipv6) {
+                return mISM.openUdpEncapsulationSocket(65536);
+            } else {
+                // Can't pass 0 to IpSecManager#openUdpEncapsulationSocket(int).
+                return mISM.openUdpEncapsulationSocket();
+            }
+        });
+    }
+
+    private void assumeExperimentalIpv6UdpEncapSupported() throws Exception {
+        assumeTrue("Not supported before U", SdkLevel.isAtLeastU());
+        assumeTrue("Not supported by kernel", isKernelVersionAtLeast("5.15.31")
+                || (isKernelVersionAtLeast("5.10.108") && !isKernelVersionAtLeast("5.15.0")));
+    }
+
+    @Test
+    public void testCreateTransformIpv4() throws Exception {
+        doTestCreateTransform(IPV4_LOOPBACK, false);
+    }
+
+    @Test
+    public void testCreateTransformIpv6() throws Exception {
+        doTestCreateTransform(IPV6_LOOPBACK, false);
+    }
+
+    @Test
+    public void testCreateTransformIpv4Encap() throws Exception {
+        doTestCreateTransform(IPV4_LOOPBACK, true);
+    }
+
+    @Test
+    public void testCreateTransformIpv6Encap() throws Exception {
+        assumeExperimentalIpv6UdpEncapSupported();
+        doTestCreateTransform(IPV6_LOOPBACK, true);
     }
 
     /** Snapshot of TrafficStats as of initStatsChecker call for later comparisons */
@@ -503,8 +557,8 @@
         StatsChecker.initStatsChecker();
         InetAddress local = InetAddress.getByName(localAddress);
 
-        try (IpSecManager.UdpEncapsulationSocket encapSocket = mISM.openUdpEncapsulationSocket();
-                IpSecManager.SecurityParameterIndex spi =
+        try (UdpEncapsulationSocket encapSocket = mISM.openUdpEncapsulationSocket();
+                SecurityParameterIndex spi =
                         mISM.allocateSecurityParameterIndex(local)) {
 
             IpSecTransform.Builder transformBuilder =
@@ -656,7 +710,7 @@
     public void testIkeOverUdpEncapSocket() throws Exception {
         // IPv6 not supported for UDP-encap-ESP
         InetAddress local = InetAddress.getByName(IPV4_LOOPBACK);
-        try (IpSecManager.UdpEncapsulationSocket encapSocket = mISM.openUdpEncapsulationSocket()) {
+        try (UdpEncapsulationSocket encapSocket = mISM.openUdpEncapsulationSocket()) {
             NativeUdpSocket wrappedEncapSocket =
                     new NativeUdpSocket(encapSocket.getFileDescriptor());
             checkIkePacket(wrappedEncapSocket, local);
@@ -665,7 +719,7 @@
             IpSecAlgorithm crypt = new IpSecAlgorithm(IpSecAlgorithm.CRYPT_AES_CBC, CRYPT_KEY);
             IpSecAlgorithm auth = new IpSecAlgorithm(IpSecAlgorithm.AUTH_HMAC_MD5, getKey(128), 96);
 
-            try (IpSecManager.SecurityParameterIndex spi =
+            try (SecurityParameterIndex spi =
                             mISM.allocateSecurityParameterIndex(local);
                     IpSecTransform transform =
                             new IpSecTransform.Builder(InstrumentationRegistry.getContext())
@@ -1498,7 +1552,7 @@
 
     @Test
     public void testOpenUdpEncapSocketSpecificPort() throws Exception {
-        IpSecManager.UdpEncapsulationSocket encapSocket = null;
+        UdpEncapsulationSocket encapSocket = null;
         int port = -1;
         for (int i = 0; i < MAX_PORT_BIND_ATTEMPTS; i++) {
             try {
@@ -1527,7 +1581,7 @@
 
     @Test
     public void testOpenUdpEncapSocketRandomPort() throws Exception {
-        try (IpSecManager.UdpEncapsulationSocket encapSocket = mISM.openUdpEncapsulationSocket()) {
+        try (UdpEncapsulationSocket encapSocket = mISM.openUdpEncapsulationSocket()) {
             assertTrue("Returned invalid port", encapSocket.getPort() != 0);
         }
     }
diff --git a/tests/cts/net/src/android/net/cts/NetworkAgentTest.kt b/tests/cts/net/src/android/net/cts/NetworkAgentTest.kt
index 6df71c8..7ae4688 100644
--- a/tests/cts/net/src/android/net/cts/NetworkAgentTest.kt
+++ b/tests/cts/net/src/android/net/cts/NetworkAgentTest.kt
@@ -391,9 +391,7 @@
             val nc = NetworkCapabilities(agent.nc)
             nc.addCapability(NET_CAPABILITY_NOT_METERED)
             agent.sendNetworkCapabilities(nc)
-            callback.expectCapabilitiesThat(agent.network) {
-                it.hasCapability(NET_CAPABILITY_NOT_METERED)
-            }
+            callback.expectCaps(agent.network) { it.hasCapability(NET_CAPABILITY_NOT_METERED) }
             val networkInfo = mCM.getNetworkInfo(agent.network)
             assertEquals(subtypeUMTS, networkInfo.getSubtype())
             assertEquals(subtypeNameUMTS, networkInfo.getSubtypeName())
@@ -434,27 +432,28 @@
             (agent, callback) ->
             // Send signal strength and check that the callbacks are called appropriately.
             val nc = NetworkCapabilities(agent.nc)
+            val net = agent.network!!
             nc.setSignalStrength(20)
             agent.sendNetworkCapabilities(nc)
             callbacks.forEach { it.assertNoCallback(NO_CALLBACK_TIMEOUT) }
 
             nc.setSignalStrength(40)
             agent.sendNetworkCapabilities(nc)
-            callbacks[0].expectAvailableCallbacks(agent.network!!)
+            callbacks[0].expectAvailableCallbacks(net)
             callbacks[1].assertNoCallback(NO_CALLBACK_TIMEOUT)
             callbacks[2].assertNoCallback(NO_CALLBACK_TIMEOUT)
 
             nc.setSignalStrength(80)
             agent.sendNetworkCapabilities(nc)
-            callbacks[0].expectCapabilitiesThat(agent.network!!) { it.signalStrength == 80 }
-            callbacks[1].expectAvailableCallbacks(agent.network!!)
-            callbacks[2].expectAvailableCallbacks(agent.network!!)
+            callbacks[0].expectCaps(net) { it.signalStrength == 80 }
+            callbacks[1].expectAvailableCallbacks(net)
+            callbacks[2].expectAvailableCallbacks(net)
 
             nc.setSignalStrength(55)
             agent.sendNetworkCapabilities(nc)
-            callbacks[0].expectCapabilitiesThat(agent.network!!) { it.signalStrength == 55 }
-            callbacks[1].expectCapabilitiesThat(agent.network!!) { it.signalStrength == 55 }
-            callbacks[2].expect<Lost>(agent.network!!)
+            callbacks[0].expectCaps(net) { it.signalStrength == 55 }
+            callbacks[1].expectCaps(net) { it.signalStrength == 55 }
+            callbacks[2].expect<Lost>(net)
         }
         callbacks.forEach {
             mCM.unregisterNetworkCallback(it)
@@ -513,9 +512,7 @@
         val nc = NetworkCapabilities(agent.nc)
         nc.addCapability(NET_CAPABILITY_NOT_METERED)
         agent.sendNetworkCapabilities(nc)
-        callback.expectCapabilitiesThat(agent.network!!) {
-            it.hasCapability(NET_CAPABILITY_NOT_METERED)
-        }
+        callback.expectCaps(agent.network!!) { it.hasCapability(NET_CAPABILITY_NOT_METERED) }
     }
 
     private fun ncWithAllowedUids(vararg uids: Int) = NetworkCapabilities.Builder()
@@ -533,12 +530,12 @@
 
         // Make sure the UIDs have been ignored.
         callback.expect<Available>(agent.network!!)
-        callback.expectCapabilitiesThat(agent.network!!) {
+        callback.expectCaps(agent.network!!) {
             it.allowedUids.isEmpty() && !it.hasCapability(NET_CAPABILITY_VALIDATED)
         }
         callback.expect<LinkPropertiesChanged>(agent.network!!)
         callback.expect<BlockedStatus>(agent.network!!)
-        callback.expectCapabilitiesThat(agent.network!!) {
+        callback.expectCaps(agent.network!!) {
             it.allowedUids.isEmpty() && it.hasCapability(NET_CAPABILITY_VALIDATED)
         }
         callback.assertNoCallback(NO_CALLBACK_TIMEOUT)
@@ -582,8 +579,8 @@
         // tearDown() will unregister the requests and agents
     }
 
-    private fun hasAllTransports(nc: NetworkCapabilities?, transports: IntArray) =
-            nc != null && transports.all { nc.hasTransport(it) }
+    private fun NetworkCapabilities?.hasAllTransports(transports: IntArray) =
+            this != null && transports.all { hasTransport(it) }
 
     @Test
     @IgnoreUpTo(Build.VERSION_CODES.R)
@@ -625,25 +622,25 @@
         assertEquals(mySessionId, (vpnNc.transportInfo as VpnTransportInfo).sessionId)
 
         val testAndVpn = intArrayOf(TRANSPORT_TEST, TRANSPORT_VPN)
-        assertTrue(hasAllTransports(vpnNc, testAndVpn))
+        assertTrue(vpnNc.hasAllTransports(testAndVpn))
         assertFalse(vpnNc.hasCapability(NET_CAPABILITY_NOT_VPN))
-        assertTrue(hasAllTransports(vpnNc, defaultNetworkTransports),
+        assertTrue(vpnNc.hasAllTransports(defaultNetworkTransports),
                 "VPN transports ${Arrays.toString(vpnNc.transportTypes)}" +
                         " lacking transports from ${Arrays.toString(defaultNetworkTransports)}")
 
         // Check that when no underlying networks are announced the underlying transport disappears.
         agent.setUnderlyingNetworks(listOf<Network>())
-        callback.expectCapabilitiesThat(agent.network!!) {
-            it.transportTypes.size == 2 && hasAllTransports(it, testAndVpn)
+        callback.expectCaps(agent.network!!) {
+            it.transportTypes.size == 2 && it.hasAllTransports(testAndVpn)
         }
 
         // Put the underlying network back and check that the underlying transport reappears.
         val expectedTransports = (defaultNetworkTransports.toSet() + TRANSPORT_TEST + TRANSPORT_VPN)
                 .toIntArray()
         agent.setUnderlyingNetworks(null)
-        callback.expectCapabilitiesThat(agent.network!!) {
+        callback.expectCaps(agent.network!!) {
             it.transportTypes.size == expectedTransports.size &&
-                    hasAllTransports(it, expectedTransports)
+                    it.hasAllTransports(expectedTransports)
         }
 
         // Check that some underlying capabilities are propagated.
@@ -757,7 +754,7 @@
         val nc1 = NetworkCapabilities(agent.nc)
                 .addCapability(NET_CAPABILITY_TEMPORARILY_NOT_METERED)
         agent.sendNetworkCapabilities(nc1)
-        callback.expectCapabilitiesThat(agent.network!!) {
+        callback.expectCaps(agent.network!!) {
             it.hasCapability(NET_CAPABILITY_TEMPORARILY_NOT_METERED)
         }
 
@@ -765,7 +762,7 @@
         val nc2 = NetworkCapabilities(agent.nc)
                 .removeCapability(NET_CAPABILITY_TEMPORARILY_NOT_METERED)
         agent.sendNetworkCapabilities(nc2)
-        callback.expectCapabilitiesThat(agent.network!!) {
+        callback.expectCaps(agent.network!!) {
             !it.hasCapability(NET_CAPABILITY_TEMPORARILY_NOT_METERED)
         }
 
@@ -917,12 +914,10 @@
         val history = ArrayTrackRecord<CallbackEntry>().newReadHead()
 
         sealed class CallbackEntry {
-            data class OnQosSessionAvailable(val sess: QosSession, val attr: QosSessionAttributes)
-                : CallbackEntry()
-            data class OnQosSessionLost(val sess: QosSession)
-                : CallbackEntry()
-            data class OnError(val ex: QosCallbackException)
-                : CallbackEntry()
+            data class OnQosSessionAvailable(val sess: QosSession, val attr: QosSessionAttributes) :
+                CallbackEntry()
+            data class OnQosSessionLost(val sess: QosSession) : CallbackEntry()
+            data class OnError(val ex: QosCallbackException) : CallbackEntry()
         }
 
         override fun onQosSessionAvailable(sess: QosSession, attr: QosSessionAttributes) {
@@ -1330,14 +1325,10 @@
 
         val (wifiAgent, wifiNetwork) = connectNetwork(TRANSPORT_WIFI)
         testCallback.expectAvailableCallbacks(wifiNetwork, validated = true)
-        testCallback.expectCapabilitiesThat(wifiNetwork) {
-            it.hasCapability(NET_CAPABILITY_VALIDATED)
-        }
+        testCallback.expectCaps(wifiNetwork) { it.hasCapability(NET_CAPABILITY_VALIDATED) }
         matchAllCallback.expectAvailableCallbacks(wifiNetwork, validated = false)
         matchAllCallback.expect<Losing>(cellNetwork)
-        matchAllCallback.expectCapabilitiesThat(wifiNetwork) {
-            it.hasCapability(NET_CAPABILITY_VALIDATED)
-        }
+        matchAllCallback.expectCaps(wifiNetwork) { it.hasCapability(NET_CAPABILITY_VALIDATED) }
 
         wifiAgent.unregisterAfterReplacement(5_000 /* timeoutMillis */)
         wifiAgent.expectCallback<OnNetworkDestroyed>()
diff --git a/tests/cts/net/src/android/net/cts/NsdManagerTest.kt b/tests/cts/net/src/android/net/cts/NsdManagerTest.kt
index b7eb009..9b27df5 100644
--- a/tests/cts/net/src/android/net/cts/NsdManagerTest.kt
+++ b/tests/cts/net/src/android/net/cts/NsdManagerTest.kt
@@ -41,9 +41,14 @@
 import android.net.cts.NsdManagerTest.NsdRegistrationRecord.RegistrationEvent.ServiceUnregistered
 import android.net.cts.NsdManagerTest.NsdRegistrationRecord.RegistrationEvent.UnregistrationFailed
 import android.net.cts.NsdManagerTest.NsdResolveRecord.ResolveEvent.ResolveFailed
-import android.net.cts.NsdManagerTest.NsdResolveRecord.ResolveEvent.ResolveStopped
+import android.net.cts.NsdManagerTest.NsdResolveRecord.ResolveEvent.ResolutionStopped
 import android.net.cts.NsdManagerTest.NsdResolveRecord.ResolveEvent.ServiceResolved
 import android.net.cts.NsdManagerTest.NsdResolveRecord.ResolveEvent.StopResolutionFailed
+import android.net.cts.NsdManagerTest.NsdServiceInfoCallbackRecord.ServiceInfoCallbackEvent.RegisterCallbackFailed
+import android.net.cts.NsdManagerTest.NsdServiceInfoCallbackRecord.ServiceInfoCallbackEvent.ServiceUpdated
+import android.net.cts.NsdManagerTest.NsdServiceInfoCallbackRecord.ServiceInfoCallbackEvent.ServiceUpdatedLost
+import android.net.cts.NsdManagerTest.NsdServiceInfoCallbackRecord.ServiceInfoCallbackEvent.UnregisterCallbackSucceeded
+import android.net.cts.util.CtsNetUtils
 import android.net.nsd.NsdManager
 import android.net.nsd.NsdManager.DiscoveryListener
 import android.net.nsd.NsdManager.RegistrationListener
@@ -60,6 +65,7 @@
 import com.android.net.module.util.ArrayTrackRecord
 import com.android.net.module.util.TrackRecord
 import com.android.networkstack.apishim.NsdShimImpl
+import com.android.networkstack.apishim.common.NsdShim
 import com.android.testutils.ConnectivityModuleTest
 import com.android.testutils.DevSdkIgnoreRule
 import com.android.testutils.TestableNetworkAgent
@@ -115,6 +121,7 @@
     private val serviceName = "NsdTest%09d".format(Random().nextInt(1_000_000_000))
     private val serviceType = "_nmt%09d._tcp".format(Random().nextInt(1_000_000_000))
     private val handlerThread = HandlerThread(NsdManagerTest::class.java.simpleName)
+    private val ctsNetUtils by lazy{ CtsNetUtils(context) }
 
     private lateinit var testNetwork1: TestTapNetwork
     private lateinit var testNetwork2: TestTapNetwork
@@ -157,7 +164,8 @@
 
         inline fun <reified V : NsdEvent> expectCallback(timeoutMs: Long = TIMEOUT_MS): V {
             val nextEvent = nextEvents.poll(timeoutMs)
-            assertNotNull(nextEvent, "No callback received after $timeoutMs ms")
+            assertNotNull(nextEvent, "No callback received after $timeoutMs ms, expected " +
+                    "${V::class.java.simpleName}")
             assertTrue(nextEvent is V, "Expected ${V::class.java.simpleName} but got " +
                     nextEvent.javaClass.simpleName)
             return nextEvent
@@ -265,7 +273,7 @@
                     ResolveEvent()
 
             data class ServiceResolved(val serviceInfo: NsdServiceInfo) : ResolveEvent()
-            data class ResolveStopped(val serviceInfo: NsdServiceInfo) : ResolveEvent()
+            data class ResolutionStopped(val serviceInfo: NsdServiceInfo) : ResolveEvent()
             data class StopResolutionFailed(val serviceInfo: NsdServiceInfo, val errorCode: Int) :
                     ResolveEvent()
         }
@@ -278,8 +286,8 @@
             add(ServiceResolved(si))
         }
 
-        override fun onResolveStopped(si: NsdServiceInfo) {
-            add(ResolveStopped(si))
+        override fun onResolutionStopped(si: NsdServiceInfo) {
+            add(ResolutionStopped(si))
         }
 
         override fun onStopResolutionFailed(si: NsdServiceInfo, err: Int) {
@@ -287,6 +295,32 @@
         }
     }
 
+    private class NsdServiceInfoCallbackRecord : NsdShim.ServiceInfoCallbackShim,
+            NsdRecord<NsdServiceInfoCallbackRecord.ServiceInfoCallbackEvent>() {
+        sealed class ServiceInfoCallbackEvent : NsdEvent {
+            data class RegisterCallbackFailed(val errorCode: Int) : ServiceInfoCallbackEvent()
+            data class ServiceUpdated(val serviceInfo: NsdServiceInfo) : ServiceInfoCallbackEvent()
+            object ServiceUpdatedLost : ServiceInfoCallbackEvent()
+            object UnregisterCallbackSucceeded : ServiceInfoCallbackEvent()
+        }
+
+        override fun onServiceInfoCallbackRegistrationFailed(err: Int) {
+            add(RegisterCallbackFailed(err))
+        }
+
+        override fun onServiceUpdated(si: NsdServiceInfo) {
+            add(ServiceUpdated(si))
+        }
+
+        override fun onServiceLost() {
+            add(ServiceUpdatedLost)
+        }
+
+        override fun onServiceInfoCallbackUnregistered() {
+            add(UnregisterCallbackSucceeded)
+        }
+    }
+
     @Before
     fun setUp() {
         handlerThread.start()
@@ -764,14 +798,72 @@
 
         val resolveRecord = NsdResolveRecord()
         // Try to resolve an unknown service then stop it immediately.
-        // Expected ResolveStopped callback.
+        // Expected ResolutionStopped callback.
         nsdShim.resolveService(nsdManager, si, { it.run() }, resolveRecord)
         nsdShim.stopServiceResolution(nsdManager, resolveRecord)
-        val stoppedCb = resolveRecord.expectCallback<ResolveStopped>()
+        val stoppedCb = resolveRecord.expectCallback<ResolutionStopped>()
         assertEquals(si.serviceName, stoppedCb.serviceInfo.serviceName)
         assertEquals(si.serviceType, stoppedCb.serviceInfo.serviceType)
     }
 
+    @Test
+    fun testRegisterServiceInfoCallback() {
+        // This test requires shims supporting U+ APIs (NsdManager.subscribeService)
+        assumeTrue(TestUtils.shouldTestUApis())
+
+        // Ensure Wi-Fi network connected and get addresses
+        val wifiNetwork = ctsNetUtils.ensureWifiConnected()
+        val lp = cm.getLinkProperties(wifiNetwork)
+        assertNotNull(lp)
+        val addresses = lp.addresses
+        assertFalse(addresses.isEmpty())
+
+        val si = NsdServiceInfo().apply {
+            serviceType = this@NsdManagerTest.serviceType
+            serviceName = this@NsdManagerTest.serviceName
+            network = wifiNetwork
+            port = 12345 // Test won't try to connect so port does not matter
+        }
+
+        // Register service on Wi-Fi network
+        val registrationRecord = NsdRegistrationRecord()
+        registerService(registrationRecord, si)
+
+        val discoveryRecord = NsdDiscoveryRecord()
+        val cbRecord = NsdServiceInfoCallbackRecord()
+        tryTest {
+            // Discover service on Wi-Fi network.
+            nsdShim.discoverServices(nsdManager, serviceType, NsdManager.PROTOCOL_DNS_SD,
+                    wifiNetwork, Executor { it.run() }, discoveryRecord)
+            val foundInfo = discoveryRecord.waitForServiceDiscovered(
+                    serviceName, wifiNetwork)
+
+            // Subscribe to service and check the addresses are the same as Wi-Fi addresses
+            nsdShim.registerServiceInfoCallback(nsdManager, foundInfo, { it.run() }, cbRecord)
+            for (i in addresses.indices) {
+                val subscribeCb = cbRecord.expectCallback<ServiceUpdated>()
+                assertEquals(foundInfo.serviceName, subscribeCb.serviceInfo.serviceName)
+                val hostAddresses = subscribeCb.serviceInfo.hostAddresses
+                assertEquals(i + 1, hostAddresses.size)
+                for (hostAddress in hostAddresses) {
+                    assertTrue(addresses.contains(hostAddress))
+                }
+            }
+        } cleanupStep {
+            nsdManager.unregisterService(registrationRecord)
+            registrationRecord.expectCallback<ServiceUnregistered>()
+            discoveryRecord.expectCallback<ServiceLost>()
+            cbRecord.expectCallback<ServiceUpdatedLost>()
+        } cleanupStep {
+            // Cancel subscription and check stop callback received.
+            nsdShim.unregisterServiceInfoCallback(nsdManager, cbRecord)
+            cbRecord.expectCallback<UnregisterCallbackSucceeded>()
+        } cleanup {
+            nsdManager.stopServiceDiscovery(discoveryRecord)
+            discoveryRecord.expectCallback<DiscoveryStopped>()
+        }
+    }
+
     /**
      * Register a service and return its registration record.
      */
diff --git a/tests/integration/src/com/android/server/net/integrationtests/ConnectivityServiceIntegrationTest.kt b/tests/integration/src/com/android/server/net/integrationtests/ConnectivityServiceIntegrationTest.kt
index 26b058d..cf3f375 100644
--- a/tests/integration/src/com/android/server/net/integrationtests/ConnectivityServiceIntegrationTest.kt
+++ b/tests/integration/src/com/android/server/net/integrationtests/ConnectivityServiceIntegrationTest.kt
@@ -37,7 +37,6 @@
 import android.net.TestNetworkStackClient
 import android.net.Uri
 import android.net.metrics.IpConnectivityLog
-import com.android.server.connectivity.MultinetworkPolicyTracker
 import android.os.ConditionVariable
 import android.os.IBinder
 import android.os.SystemConfigManager
@@ -52,6 +51,7 @@
 import com.android.server.NetworkAgentWrapper
 import com.android.server.TestNetIdManager
 import com.android.server.connectivity.MockableSystemProperties
+import com.android.server.connectivity.MultinetworkPolicyTracker
 import com.android.server.connectivity.ProxyTracker
 import com.android.testutils.TestableNetworkCallback
 import org.junit.After
@@ -73,7 +73,6 @@
 import org.mockito.MockitoAnnotations
 import org.mockito.Spy
 import kotlin.test.assertEquals
-import kotlin.test.assertFalse
 import kotlin.test.assertNotNull
 import kotlin.test.assertTrue
 import kotlin.test.fail
@@ -297,7 +296,9 @@
         assertEquals(Uri.parse("https://login.capport.android.com"), capportData.userPortalUrl)
         assertEquals(Uri.parse("https://venueinfo.capport.android.com"), capportData.venueInfoUrl)
 
-        val nc = testCb.expectCapabilitiesWith(NET_CAPABILITY_CAPTIVE_PORTAL, na, TEST_TIMEOUT_MS)
-        assertFalse(nc.hasCapability(NET_CAPABILITY_VALIDATED))
+        testCb.expectCaps(na, TEST_TIMEOUT_MS) {
+            it.hasCapability(NET_CAPABILITY_CAPTIVE_PORTAL) &&
+                    !it.hasCapability(NET_CAPABILITY_VALIDATED)
+        }
     }
-}
\ No newline at end of file
+}
diff --git a/tests/unit/Android.bp b/tests/unit/Android.bp
index e0de246..8db307d 100644
--- a/tests/unit/Android.bp
+++ b/tests/unit/Android.bp
@@ -114,6 +114,7 @@
         "service-connectivity-pre-jarjar",
         "service-connectivity-tiramisu-pre-jarjar",
         "services.core-vpn",
+        "testables",
         "cts-net-utils"
     ],
     libs: [
diff --git a/tests/unit/java/com/android/server/ConnectivityServiceTest.java b/tests/unit/java/com/android/server/ConnectivityServiceTest.java
index a2d284b..41ed4ff 100755
--- a/tests/unit/java/com/android/server/ConnectivityServiceTest.java
+++ b/tests/unit/java/com/android/server/ConnectivityServiceTest.java
@@ -138,6 +138,7 @@
 import static android.net.OemNetworkPreferences.OEM_NETWORK_PREFERENCE_TEST;
 import static android.net.OemNetworkPreferences.OEM_NETWORK_PREFERENCE_TEST_ONLY;
 import static android.net.OemNetworkPreferences.OEM_NETWORK_PREFERENCE_UNINITIALIZED;
+import static android.net.Proxy.PROXY_CHANGE_ACTION;
 import static android.net.RouteInfo.RTN_UNREACHABLE;
 import static android.net.resolv.aidl.IDnsResolverUnsolicitedEventListener.PREFIX_OPERATION_ADDED;
 import static android.net.resolv.aidl.IDnsResolverUnsolicitedEventListener.PREFIX_OPERATION_REMOVED;
@@ -2273,22 +2274,15 @@
         }
     }
 
-    /** Expects that {@code count} CONNECTIVITY_ACTION broadcasts are received. */
-    private ExpectedBroadcast registerConnectivityBroadcast(final int count) {
-        return registerConnectivityBroadcastThat(count, intent -> true);
-    }
-
-    private ExpectedBroadcast registerConnectivityBroadcastThat(final int count,
+    private ExpectedBroadcast registerBroadcastReceiverThat(final String action, final int count,
             @NonNull final Predicate<Intent> filter) {
-        final IntentFilter intentFilter = new IntentFilter(CONNECTIVITY_ACTION);
+        final IntentFilter intentFilter = new IntentFilter(action);
         // AtomicReference allows receiver to access expected even though it is constructed later.
         final AtomicReference<ExpectedBroadcast> expectedRef = new AtomicReference<>();
         final BroadcastReceiver receiver = new BroadcastReceiver() {
             private int mRemaining = count;
             public void onReceive(Context context, Intent intent) {
-                final int type = intent.getIntExtra(EXTRA_NETWORK_TYPE, -1);
-                final NetworkInfo ni = intent.getParcelableExtra(EXTRA_NETWORK_INFO);
-                Log.d(TAG, "Received CONNECTIVITY_ACTION type=" + type + " ni=" + ni);
+                logIntent(intent);
                 if (!filter.test(intent)) return;
                 if (--mRemaining == 0) {
                     expectedRef.get().complete(intent);
@@ -2301,39 +2295,49 @@
         return expected;
     }
 
+    private void logIntent(Intent intent) {
+        final String action = intent.getAction();
+        if (CONNECTIVITY_ACTION.equals(action)) {
+            final int type = intent.getIntExtra(EXTRA_NETWORK_TYPE, -1);
+            final NetworkInfo ni = intent.getParcelableExtra(EXTRA_NETWORK_INFO);
+            Log.d(TAG, "Received " + action + ", type=" + type + " ni=" + ni);
+        } else if (PROXY_CHANGE_ACTION.equals(action)) {
+            final ProxyInfo proxy = (ProxyInfo) intent.getExtra(
+                    Proxy.EXTRA_PROXY_INFO, ProxyInfo.buildPacProxy(Uri.EMPTY));
+            Log.d(TAG, "Received " + action + ", proxy = " + proxy);
+        } else {
+            throw new IllegalArgumentException("Unsupported logging " + action);
+        }
+    }
+
+    /** Expects that {@code count} CONNECTIVITY_ACTION broadcasts are received. */
+    private ExpectedBroadcast expectConnectivityAction(final int count) {
+        return registerBroadcastReceiverThat(CONNECTIVITY_ACTION, count, intent -> true);
+    }
+
+    private ExpectedBroadcast expectConnectivityAction(int type, NetworkInfo.DetailedState state) {
+        return registerBroadcastReceiverThat(CONNECTIVITY_ACTION, 1, intent -> {
+            final int actualType = intent.getIntExtra(EXTRA_NETWORK_TYPE, -1);
+            final NetworkInfo ni = intent.getParcelableExtra(EXTRA_NETWORK_INFO);
+            return type == actualType
+                    && state == ni.getDetailedState()
+                    && extraInfoInBroadcastHasExpectedNullness(ni);
+        });
+    }
+
+    /** Expects that PROXY_CHANGE_ACTION broadcast is received. */
+    private ExpectedBroadcast expectProxyChangeAction() {
+        return registerBroadcastReceiverThat(PROXY_CHANGE_ACTION, 1, intent -> true);
+    }
+
     private ExpectedBroadcast expectProxyChangeAction(ProxyInfo proxy) {
-        return registerPacProxyBroadcastThat(intent -> {
+        return registerBroadcastReceiverThat(PROXY_CHANGE_ACTION, 1, intent -> {
             final ProxyInfo actualProxy = (ProxyInfo) intent.getExtra(Proxy.EXTRA_PROXY_INFO,
                     ProxyInfo.buildPacProxy(Uri.EMPTY));
             return proxy.equals(actualProxy);
         });
     }
 
-    private ExpectedBroadcast registerPacProxyBroadcast() {
-        return registerPacProxyBroadcastThat(intent -> true);
-    }
-
-    private ExpectedBroadcast registerPacProxyBroadcastThat(
-            @NonNull final Predicate<Intent> filter) {
-        final IntentFilter intentFilter = new IntentFilter(Proxy.PROXY_CHANGE_ACTION);
-        // AtomicReference allows receiver to access expected even though it is constructed later.
-        final AtomicReference<ExpectedBroadcast> expectedRef = new AtomicReference<>();
-        final BroadcastReceiver receiver = new BroadcastReceiver() {
-            public void onReceive(Context context, Intent intent) {
-                final ProxyInfo proxy = (ProxyInfo) intent.getExtra(
-                            Proxy.EXTRA_PROXY_INFO, ProxyInfo.buildPacProxy(Uri.EMPTY));
-                Log.d(TAG, "Receive PROXY_CHANGE_ACTION, proxy = " + proxy);
-                if (filter.test(intent)) {
-                    expectedRef.get().complete(intent);
-                }
-            }
-        };
-        final ExpectedBroadcast expected = new ExpectedBroadcast(receiver);
-        expectedRef.set(expected);
-        mServiceContext.registerReceiver(receiver, intentFilter);
-        return expected;
-    }
-
     private boolean extraInfoInBroadcastHasExpectedNullness(NetworkInfo ni) {
         final DetailedState state = ni.getDetailedState();
         if (state == DetailedState.CONNECTED && ni.getExtraInfo() == null) return false;
@@ -2349,16 +2353,6 @@
         return true;
     }
 
-    private ExpectedBroadcast expectConnectivityAction(int type, NetworkInfo.DetailedState state) {
-        return registerConnectivityBroadcastThat(1, intent -> {
-            final int actualType = intent.getIntExtra(EXTRA_NETWORK_TYPE, -1);
-            final NetworkInfo ni = intent.getParcelableExtra(EXTRA_NETWORK_INFO);
-            return type == actualType
-                    && state == ni.getDetailedState()
-                    && extraInfoInBroadcastHasExpectedNullness(ni);
-        });
-    }
-
     @Test
     public void testNetworkTypes() {
         // Ensure that our mocks for the networkAttributes config variable work as expected. If they
@@ -2393,7 +2387,7 @@
                 ConnectivityManager.REQUEST_ID_UNSET, NetworkRequest.Type.REQUEST);
 
         // File request, withdraw it and make sure no broadcast is sent
-        b = registerConnectivityBroadcast(1);
+        b = expectConnectivityAction(1);
         final TestNetworkCallback callback = new TestNetworkCallback();
         mCm.requestNetwork(legacyRequest, callback);
         callback.expect(AVAILABLE, mCellAgent);
@@ -2424,7 +2418,7 @@
         assertTrue(mCm.getAllNetworks()[0].equals(mWiFiAgent.getNetwork())
                 || mCm.getAllNetworks()[1].equals(mWiFiAgent.getNetwork()));
         // Test bringing up validated WiFi.
-        b = registerConnectivityBroadcast(2);
+        b = expectConnectivityAction(2);
         mWiFiAgent.connect(true);
         b.expectBroadcast();
         verifyActiveNetwork(TRANSPORT_WIFI);
@@ -2441,7 +2435,7 @@
         assertLength(1, mCm.getAllNetworks());
         assertEquals(mCm.getAllNetworks()[0], mCm.getActiveNetwork());
         // Test WiFi disconnect.
-        b = registerConnectivityBroadcast(1);
+        b = expectConnectivityAction(1);
         mWiFiAgent.disconnect();
         b.expectBroadcast();
         verifyNoNetwork();
@@ -2607,7 +2601,7 @@
         mService.mCellularRadioTimesharingCapable = cellRadioTimesharingCapable;
         // Test bringing up unvalidated WiFi
         mWiFiAgent = new TestNetworkAgentWrapper(TRANSPORT_WIFI);
-        ExpectedBroadcast b = registerConnectivityBroadcast(1);
+        ExpectedBroadcast b = expectConnectivityAction(1);
         mWiFiAgent.connect(false);
         b.expectBroadcast();
         verifyActiveNetwork(TRANSPORT_WIFI);
@@ -2622,17 +2616,17 @@
         verifyActiveNetwork(TRANSPORT_WIFI);
         // Test bringing up validated cellular
         mCellAgent = new TestNetworkAgentWrapper(TRANSPORT_CELLULAR);
-        b = registerConnectivityBroadcast(2);
+        b = expectConnectivityAction(2);
         mCellAgent.connect(true);
         b.expectBroadcast();
         verifyActiveNetwork(TRANSPORT_CELLULAR);
         // Test cellular disconnect.
-        b = registerConnectivityBroadcast(2);
+        b = expectConnectivityAction(2);
         mCellAgent.disconnect();
         b.expectBroadcast();
         verifyActiveNetwork(TRANSPORT_WIFI);
         // Test WiFi disconnect.
-        b = registerConnectivityBroadcast(1);
+        b = expectConnectivityAction(1);
         mWiFiAgent.disconnect();
         b.expectBroadcast();
         verifyNoNetwork();
@@ -2655,23 +2649,23 @@
         mService.mCellularRadioTimesharingCapable = cellRadioTimesharingCapable;
         // Test bringing up unvalidated cellular.
         mCellAgent = new TestNetworkAgentWrapper(TRANSPORT_CELLULAR);
-        ExpectedBroadcast b = registerConnectivityBroadcast(1);
+        ExpectedBroadcast b = expectConnectivityAction(1);
         mCellAgent.connect(false);
         b.expectBroadcast();
         verifyActiveNetwork(TRANSPORT_CELLULAR);
         // Test bringing up unvalidated WiFi.
         mWiFiAgent = new TestNetworkAgentWrapper(TRANSPORT_WIFI);
-        b = registerConnectivityBroadcast(2);
+        b = expectConnectivityAction(2);
         mWiFiAgent.connect(false);
         b.expectBroadcast();
         verifyActiveNetwork(TRANSPORT_WIFI);
         // Test WiFi disconnect.
-        b = registerConnectivityBroadcast(2);
+        b = expectConnectivityAction(2);
         mWiFiAgent.disconnect();
         b.expectBroadcast();
         verifyActiveNetwork(TRANSPORT_CELLULAR);
         // Test cellular disconnect.
-        b = registerConnectivityBroadcast(1);
+        b = expectConnectivityAction(1);
         mCellAgent.disconnect();
         b.expectBroadcast();
         verifyNoNetwork();
@@ -2694,7 +2688,7 @@
         mService.mCellularRadioTimesharingCapable = cellRadioTimesharingCapable;
         // Test bringing up unvalidated WiFi.
         mWiFiAgent = new TestNetworkAgentWrapper(TRANSPORT_WIFI);
-        ExpectedBroadcast b = registerConnectivityBroadcast(1);
+        ExpectedBroadcast b = expectConnectivityAction(1);
         mWiFiAgent.connect(false);
         b.expectBroadcast();
         verifyActiveNetwork(TRANSPORT_WIFI);
@@ -2702,14 +2696,14 @@
                 NET_CAPABILITY_VALIDATED));
         // Test bringing up validated cellular.
         mCellAgent = new TestNetworkAgentWrapper(TRANSPORT_CELLULAR);
-        b = registerConnectivityBroadcast(2);
+        b = expectConnectivityAction(2);
         mCellAgent.connect(true);
         b.expectBroadcast();
         verifyActiveNetwork(TRANSPORT_CELLULAR);
         assertFalse(mCm.getNetworkCapabilities(mWiFiAgent.getNetwork()).hasCapability(
                 NET_CAPABILITY_VALIDATED));
         // Test cellular disconnect.
-        b = registerConnectivityBroadcast(2);
+        b = expectConnectivityAction(2);
         mCellAgent.disconnect();
         b.expectBroadcast();
         verifyActiveNetwork(TRANSPORT_WIFI);
@@ -2771,7 +2765,7 @@
         if (expectLingering) {
             generalCb.expectLosing(net1);
         }
-        generalCb.expectCapabilitiesWith(NET_CAPABILITY_VALIDATED, net2);
+        generalCb.expectCaps(net2, c -> c.hasCapability(NET_CAPABILITY_VALIDATED));
         defaultCb.expectAvailableDoubleValidatedCallbacks(net2);
 
         // Make sure cell 1 is unwanted immediately if the radio can't time share, but only
@@ -2849,23 +2843,23 @@
         mService.mCellularRadioTimesharingCapable = cellRadioTimesharingCapable;
         // Test bringing up validated cellular.
         mCellAgent = new TestNetworkAgentWrapper(TRANSPORT_CELLULAR);
-        ExpectedBroadcast b = registerConnectivityBroadcast(1);
+        ExpectedBroadcast b = expectConnectivityAction(1);
         mCellAgent.connect(true);
         b.expectBroadcast();
         verifyActiveNetwork(TRANSPORT_CELLULAR);
         // Test bringing up validated WiFi.
         mWiFiAgent = new TestNetworkAgentWrapper(TRANSPORT_WIFI);
-        b = registerConnectivityBroadcast(2);
+        b = expectConnectivityAction(2);
         mWiFiAgent.connect(true);
         b.expectBroadcast();
         verifyActiveNetwork(TRANSPORT_WIFI);
         // Test WiFi getting really weak.
-        b = registerConnectivityBroadcast(2);
+        b = expectConnectivityAction(2);
         mWiFiAgent.adjustScore(-11);
         b.expectBroadcast();
         verifyActiveNetwork(TRANSPORT_CELLULAR);
         // Test WiFi restoring signal strength.
-        b = registerConnectivityBroadcast(2);
+        b = expectConnectivityAction(2);
         mWiFiAgent.adjustScore(11);
         b.expectBroadcast();
         verifyActiveNetwork(TRANSPORT_WIFI);
@@ -2930,18 +2924,18 @@
         mService.mCellularRadioTimesharingCapable = cellRadioTimesharingCapable;
         // Test bringing up validated cellular.
         mCellAgent = new TestNetworkAgentWrapper(TRANSPORT_CELLULAR);
-        ExpectedBroadcast b = registerConnectivityBroadcast(1);
+        ExpectedBroadcast b = expectConnectivityAction(1);
         mCellAgent.connect(true);
         b.expectBroadcast();
         verifyActiveNetwork(TRANSPORT_CELLULAR);
         // Test bringing up validated WiFi.
         mWiFiAgent = new TestNetworkAgentWrapper(TRANSPORT_WIFI);
-        b = registerConnectivityBroadcast(2);
+        b = expectConnectivityAction(2);
         mWiFiAgent.connect(true);
         b.expectBroadcast();
         verifyActiveNetwork(TRANSPORT_WIFI);
         // Reevaluate WiFi (it'll instantly fail DNS).
-        b = registerConnectivityBroadcast(2);
+        b = expectConnectivityAction(2);
         assertTrue(mCm.getNetworkCapabilities(mWiFiAgent.getNetwork()).hasCapability(
                 NET_CAPABILITY_VALIDATED));
         mCm.reportBadNetwork(mWiFiAgent.getNetwork());
@@ -2951,7 +2945,7 @@
                 NET_CAPABILITY_VALIDATED));
         verifyActiveNetwork(TRANSPORT_CELLULAR);
         // Reevaluate cellular (it'll instantly fail DNS).
-        b = registerConnectivityBroadcast(2);
+        b = expectConnectivityAction(2);
         assertTrue(mCm.getNetworkCapabilities(mCellAgent.getNetwork()).hasCapability(
                 NET_CAPABILITY_VALIDATED));
         mCm.reportBadNetwork(mCellAgent.getNetwork());
@@ -2981,18 +2975,18 @@
         mService.mCellularRadioTimesharingCapable = cellRadioTimesharingCapable;
         // Test bringing up unvalidated WiFi.
         mWiFiAgent = new TestNetworkAgentWrapper(TRANSPORT_WIFI);
-        ExpectedBroadcast b = registerConnectivityBroadcast(1);
+        ExpectedBroadcast b = expectConnectivityAction(1);
         mWiFiAgent.connect(false);
         b.expectBroadcast();
         verifyActiveNetwork(TRANSPORT_WIFI);
         // Test bringing up validated cellular.
         mCellAgent = new TestNetworkAgentWrapper(TRANSPORT_CELLULAR);
-        b = registerConnectivityBroadcast(2);
+        b = expectConnectivityAction(2);
         mCellAgent.connect(true);
         b.expectBroadcast();
         verifyActiveNetwork(TRANSPORT_CELLULAR);
         // Reevaluate cellular (it'll instantly fail DNS).
-        b = registerConnectivityBroadcast(2);
+        b = expectConnectivityAction(2);
         assertTrue(mCm.getNetworkCapabilities(mCellAgent.getNetwork()).hasCapability(
                 NET_CAPABILITY_VALIDATED));
         mCm.reportBadNetwork(mCellAgent.getNetwork());
@@ -3070,9 +3064,8 @@
             NetworkSpecifier specifier, TestNetworkCallback ... callbacks) {
         for (TestNetworkCallback c : callbacks) {
             c.expect(AVAILABLE, network);
-            c.expectCapabilitiesThat(network, (nc) ->
-                    !nc.hasCapability(NET_CAPABILITY_VALIDATED)
-                            && Objects.equals(specifier, nc.getNetworkSpecifier()));
+            c.expectCaps(network, cb -> !cb.hasCapability(NET_CAPABILITY_VALIDATED)
+                    && Objects.equals(specifier, cb.getNetworkSpecifier()));
             c.expect(LINK_PROPERTIES_CHANGED, network);
             c.expect(BLOCKED_STATUS, network);
         }
@@ -3131,7 +3124,7 @@
         mCm.registerNetworkCallback(cellRequest, cellNetworkCallback);
 
         // Test unvalidated networks
-        ExpectedBroadcast b = registerConnectivityBroadcast(1);
+        ExpectedBroadcast b = expectConnectivityAction(1);
         mCellAgent = new TestNetworkAgentWrapper(TRANSPORT_CELLULAR);
         mCellAgent.connect(false);
         genericNetworkCallback.expectAvailableCallbacksUnvalidated(mCellAgent);
@@ -3146,7 +3139,7 @@
         assertNoCallbacks(genericNetworkCallback, wifiNetworkCallback, cellNetworkCallback);
         assertEquals(mCellAgent.getNetwork(), mCm.getActiveNetwork());
 
-        b = registerConnectivityBroadcast(2);
+        b = expectConnectivityAction(2);
         mWiFiAgent = new TestNetworkAgentWrapper(TRANSPORT_WIFI);
         mWiFiAgent.connect(false);
         genericNetworkCallback.expectAvailableCallbacksUnvalidated(mWiFiAgent);
@@ -3155,18 +3148,18 @@
         b.expectBroadcast();
         assertNoCallbacks(genericNetworkCallback, wifiNetworkCallback, cellNetworkCallback);
 
-        b = registerConnectivityBroadcast(2);
+        b = expectConnectivityAction(2);
         mWiFiAgent.disconnect();
-        genericNetworkCallback.expect(LOST, mWiFiAgent);
-        wifiNetworkCallback.expect(LOST, mWiFiAgent);
+        genericNetworkCallback.expect(CallbackEntry.LOST, mWiFiAgent);
+        wifiNetworkCallback.expect(CallbackEntry.LOST, mWiFiAgent);
         cellNetworkCallback.assertNoCallback();
         b.expectBroadcast();
         assertNoCallbacks(genericNetworkCallback, wifiNetworkCallback, cellNetworkCallback);
 
-        b = registerConnectivityBroadcast(1);
+        b = expectConnectivityAction(1);
         mCellAgent.disconnect();
-        genericNetworkCallback.expect(LOST, mCellAgent);
-        cellNetworkCallback.expect(LOST, mCellAgent);
+        genericNetworkCallback.expect(CallbackEntry.LOST, mCellAgent);
+        cellNetworkCallback.expect(CallbackEntry.LOST, mCellAgent);
         b.expectBroadcast();
         assertNoCallbacks(genericNetworkCallback, wifiNetworkCallback, cellNetworkCallback);
 
@@ -3188,7 +3181,8 @@
         mWiFiAgent.connect(true);
         genericNetworkCallback.expectAvailableCallbacksUnvalidated(mWiFiAgent);
         genericNetworkCallback.expectLosing(mCellAgent);
-        genericNetworkCallback.expectCapabilitiesWith(NET_CAPABILITY_VALIDATED, mWiFiAgent);
+        genericNetworkCallback.expectCaps(mWiFiAgent,
+                c -> c.hasCapability(NET_CAPABILITY_VALIDATED));
         wifiNetworkCallback.expectAvailableThenValidatedCallbacks(mWiFiAgent);
         cellNetworkCallback.expectLosing(mCellAgent);
         assertEquals(mWiFiAgent.getNetwork(), mCm.getActiveNetwork());
@@ -3337,7 +3331,7 @@
         callback.expectAvailableCallbacksUnvalidated(mWiFiAgent);
         // TODO: Investigate sending validated before losing.
         callback.expectLosing(mCellAgent);
-        callback.expectCapabilitiesWith(NET_CAPABILITY_VALIDATED, mWiFiAgent);
+        callback.expectCaps(mWiFiAgent, c -> c.hasCapability(NET_CAPABILITY_VALIDATED));
         defaultCallback.expectAvailableDoubleValidatedCallbacks(mWiFiAgent);
         assertEquals(mWiFiAgent.getNetwork(), mCm.getActiveNetwork());
         assertEquals(defaultCallback.getLastAvailableNetwork(), mCm.getActiveNetwork());
@@ -3346,7 +3340,7 @@
         callback.expectAvailableCallbacksUnvalidated(mEthernetAgent);
         // TODO: Investigate sending validated before losing.
         callback.expectLosing(mWiFiAgent);
-        callback.expectCapabilitiesWith(NET_CAPABILITY_VALIDATED, mEthernetAgent);
+        callback.expectCaps(mEthernetAgent, c -> c.hasCapability(NET_CAPABILITY_VALIDATED));
         defaultCallback.expectAvailableDoubleValidatedCallbacks(mEthernetAgent);
         assertEquals(mEthernetAgent.getNetwork(), mCm.getActiveNetwork());
         assertEquals(defaultCallback.getLastAvailableNetwork(), mCm.getActiveNetwork());
@@ -3381,7 +3375,7 @@
         // if the network is still up.
         mWiFiAgent.removeCapability(NET_CAPABILITY_NOT_METERED);
         // We expect a notification about the capabilities change, and nothing else.
-        defaultCallback.expectCapabilitiesWithout(NET_CAPABILITY_NOT_METERED, mWiFiAgent);
+        defaultCallback.expectCaps(mWiFiAgent, c -> !c.hasCapability(NET_CAPABILITY_NOT_METERED));
         defaultCallback.assertNoCallback();
         callback.expect(LOST, mWiFiAgent);
         assertEquals(defaultCallback.getLastAvailableNetwork(), mCm.getActiveNetwork());
@@ -3440,7 +3434,7 @@
         callback.expectAvailableCallbacksUnvalidated(mWiFiAgent);
         // TODO: Investigate sending validated before losing.
         callback.expectLosing(mCellAgent);
-        callback.expectCapabilitiesWith(NET_CAPABILITY_VALIDATED, mWiFiAgent);
+        callback.expectCaps(mWiFiAgent, c -> c.hasCapability(NET_CAPABILITY_VALIDATED));
         defaultCallback.expectAvailableThenValidatedCallbacks(mWiFiAgent);
         assertEquals(mWiFiAgent.getNetwork(), mCm.getActiveNetwork());
         assertEquals(defaultCallback.getLastAvailableNetwork(), mCm.getActiveNetwork());
@@ -3467,7 +3461,7 @@
         callback.expectAvailableCallbacksUnvalidated(mWiFiAgent);
         // TODO: Investigate sending validated before losing.
         callback.expectLosing(mCellAgent);
-        callback.expectCapabilitiesWith(NET_CAPABILITY_VALIDATED, mWiFiAgent);
+        callback.expectCaps(mWiFiAgent, c -> c.hasCapability(NET_CAPABILITY_VALIDATED));
         assertEquals(defaultCallback.getLastAvailableNetwork(), mCm.getActiveNetwork());
 
         NetworkRequest cellRequest = new NetworkRequest.Builder()
@@ -3517,7 +3511,7 @@
         mEthernetAgent.connect(true);
         callback.expectAvailableCallbacksUnvalidated(mEthernetAgent);
         callback.expectLosing(mWiFiAgent);
-        callback.expectCapabilitiesWith(NET_CAPABILITY_VALIDATED, mEthernetAgent);
+        callback.expectCaps(mEthernetAgent, c -> c.hasCapability(NET_CAPABILITY_VALIDATED));
         trackDefaultCallback.expectAvailableDoubleValidatedCallbacks(mEthernetAgent);
         defaultCallback.expectAvailableDoubleValidatedCallbacks(mEthernetAgent);
         assertEquals(defaultCallback.getLastAvailableNetwork(), mCm.getActiveNetwork());
@@ -3577,7 +3571,7 @@
         defaultCallback.expectAvailableDoubleValidatedCallbacks(mWiFiAgent);
         callback.expectAvailableCallbacksUnvalidated(mWiFiAgent);
         callback.expectLosing(mCellAgent);
-        callback.expectCapabilitiesWith(NET_CAPABILITY_VALIDATED, mWiFiAgent);
+        callback.expectCaps(mWiFiAgent, c -> c.hasCapability(NET_CAPABILITY_VALIDATED));
 
         // File a request for cellular, then release it.
         NetworkRequest cellRequest = new NetworkRequest.Builder()
@@ -3590,7 +3584,8 @@
         // Let linger run its course.
         callback.assertNoCallback();
         final int lingerTimeoutMs = TEST_LINGER_DELAY_MS + TEST_LINGER_DELAY_MS / 4;
-        callback.expectCapabilitiesWithout(NET_CAPABILITY_FOREGROUND, mCellAgent, lingerTimeoutMs);
+        callback.expectCaps(mCellAgent, lingerTimeoutMs,
+                c -> !c.hasCapability(NET_CAPABILITY_FOREGROUND));
 
         // Clean up.
         mCm.unregisterNetworkCallback(defaultCallback);
@@ -3812,7 +3807,7 @@
         mWiFiAgent.connect(true);
         callback.expectAvailableCallbacksUnvalidated(mWiFiAgent);
         callback.expectLosing(mCellAgent);
-        callback.expectCapabilitiesWith(NET_CAPABILITY_VALIDATED, mWiFiAgent);
+        callback.expectCaps(mWiFiAgent, c -> c.hasCapability(NET_CAPABILITY_VALIDATED));
         assertEquals(mWiFiAgent.getNetwork(), mCm.getActiveNetwork());
         expectUnvalidationCheckWillNotNotify(mWiFiAgent);
 
@@ -3820,7 +3815,7 @@
         mEthernetAgent.connect(true);
         callback.expectAvailableCallbacksUnvalidated(mEthernetAgent);
         callback.expectLosing(mWiFiAgent);
-        callback.expectCapabilitiesWith(NET_CAPABILITY_VALIDATED, mEthernetAgent);
+        callback.expectCaps(mEthernetAgent, c -> c.hasCapability(NET_CAPABILITY_VALIDATED));
         assertEquals(mEthernetAgent.getNetwork(), mCm.getActiveNetwork());
         callback.assertNoCallback();
 
@@ -4268,7 +4263,7 @@
         mWiFiAgent = new TestNetworkAgentWrapper(TRANSPORT_WIFI);
         mWiFiAgent.connectWithPartialConnectivity();
         callback.expectAvailableCallbacksUnvalidated(mWiFiAgent);
-        callback.expectCapabilitiesWith(NET_CAPABILITY_PARTIAL_CONNECTIVITY, mWiFiAgent);
+        callback.expectCaps(mWiFiAgent, c -> c.hasCapability(NET_CAPABILITY_PARTIAL_CONNECTIVITY));
 
         // Mobile data should be the default network.
         assertEquals(mCellAgent.getNetwork(), mCm.getActiveNetwork());
@@ -4296,8 +4291,8 @@
         // validated.
         mCm.reportNetworkConnectivity(mWiFiAgent.getNetwork(), true);
         callback.expectLosing(mCellAgent);
-        NetworkCapabilities nc = callback.expectCapabilitiesWith(NET_CAPABILITY_VALIDATED,
-                mWiFiAgent);
+        NetworkCapabilities nc =
+                callback.expectCaps(mWiFiAgent, c -> c.hasCapability(NET_CAPABILITY_VALIDATED));
         assertTrue(nc.hasCapability(NET_CAPABILITY_PARTIAL_CONNECTIVITY));
         assertEquals(mWiFiAgent.getNetwork(), mCm.getActiveNetwork());
 
@@ -4311,7 +4306,7 @@
         mWiFiAgent = new TestNetworkAgentWrapper(TRANSPORT_WIFI);
         mWiFiAgent.connectWithPartialConnectivity();
         callback.expectAvailableCallbacksUnvalidated(mWiFiAgent);
-        callback.expectCapabilitiesWith(NET_CAPABILITY_PARTIAL_CONNECTIVITY, mWiFiAgent);
+        callback.expectCaps(mWiFiAgent, c -> c.hasCapability(NET_CAPABILITY_PARTIAL_CONNECTIVITY));
 
         // Mobile data should be the default network.
         assertEquals(mCellAgent.getNetwork(), mCm.getActiveNetwork());
@@ -4343,7 +4338,7 @@
         callback.expectAvailableCallbacksUnvalidated(mWiFiAgent);
         verify(mWiFiAgent.mNetworkMonitor, times(1)).setAcceptPartialConnectivity();
         callback.expectLosing(mCellAgent);
-        nc = callback.expectCapabilitiesWith(NET_CAPABILITY_VALIDATED, mWiFiAgent);
+        nc = callback.expectCaps(mWiFiAgent, c -> c.hasCapability(NET_CAPABILITY_VALIDATED));
         assertFalse(nc.hasCapability(NET_CAPABILITY_PARTIAL_CONNECTIVITY));
 
         // Wifi should be the default network.
@@ -4364,7 +4359,7 @@
         verify(mWiFiAgent.mNetworkMonitor, times(1)).setAcceptPartialConnectivity();
         callback.expectLosing(mCellAgent);
         assertEquals(mWiFiAgent.getNetwork(), mCm.getActiveNetwork());
-        callback.expectCapabilitiesWith(NET_CAPABILITY_PARTIAL_CONNECTIVITY, mWiFiAgent);
+        callback.expectCaps(mWiFiAgent, c -> c.hasCapability(NET_CAPABILITY_PARTIAL_CONNECTIVITY));
         expectUnvalidationCheckWillNotNotify(mWiFiAgent);
 
         mWiFiAgent.setNetworkValid(false /* privateDnsProbeSent */);
@@ -4372,7 +4367,7 @@
         // Need a trigger point to let NetworkMonitor tell ConnectivityService that the network is
         // validated.
         mCm.reportNetworkConnectivity(mWiFiAgent.getNetwork(), true);
-        callback.expectCapabilitiesWith(NET_CAPABILITY_VALIDATED, mWiFiAgent);
+        callback.expectCaps(mWiFiAgent, c -> c.hasCapability(NET_CAPABILITY_VALIDATED));
         mWiFiAgent.disconnect();
         callback.expect(LOST, mWiFiAgent);
 
@@ -4388,8 +4383,8 @@
         callback.expectAvailableCallbacksUnvalidated(mWiFiAgent);
         verify(mWiFiAgent.mNetworkMonitor, times(1)).setAcceptPartialConnectivity();
         callback.expectLosing(mCellAgent);
-        callback.expectCapabilitiesWith(
-                NET_CAPABILITY_PARTIAL_CONNECTIVITY | NET_CAPABILITY_VALIDATED, mWiFiAgent);
+        callback.expectCaps(mWiFiAgent, c -> c.hasCapability(NET_CAPABILITY_PARTIAL_CONNECTIVITY)
+                && c.hasCapability(NET_CAPABILITY_VALIDATED));
         expectUnvalidationCheckWillNotNotify(mWiFiAgent);
         mWiFiAgent.disconnect();
         callback.expect(LOST, mWiFiAgent);
@@ -4420,7 +4415,7 @@
         // This is necessary because of b/245893397, the same bug that happens where we use
         // expectAvailableDoubleValidatedCallbacks.
         // TODO : fix b/245893397 and remove this.
-        wifiCallback.expectCapabilitiesWith(NET_CAPABILITY_CAPTIVE_PORTAL, mWiFiAgent);
+        wifiCallback.expectCaps(mWiFiAgent, c -> c.hasCapability(NET_CAPABILITY_CAPTIVE_PORTAL));
 
         // Check that startCaptivePortalApp sends the expected command to NetworkMonitor.
         mCm.startCaptivePortalApp(mWiFiAgent.getNetwork());
@@ -4431,9 +4426,9 @@
         mWiFiAgent.setNetworkPartial();
         mCm.reportNetworkConnectivity(mWiFiAgent.getNetwork(), true);
         waitForIdle();
-        wifiCallback.expectCapabilitiesThat(mWiFiAgent, nc ->
-                nc.hasCapability(NET_CAPABILITY_PARTIAL_CONNECTIVITY)
-                        && !nc.hasCapability(NET_CAPABILITY_CAPTIVE_PORTAL));
+        wifiCallback.expectCaps(mWiFiAgent,
+                c -> c.hasCapability(NET_CAPABILITY_PARTIAL_CONNECTIVITY)
+                        && !c.hasCapability(NET_CAPABILITY_CAPTIVE_PORTAL));
 
         // Report partial connectivity is accepted.
         mWiFiAgent.setNetworkPartialValid(false /* privateDnsProbeSent */);
@@ -4441,9 +4436,10 @@
                 false /* always */);
         waitForIdle();
         mCm.reportNetworkConnectivity(mWiFiAgent.getNetwork(), true);
-        wifiCallback.expectCapabilitiesWith(NET_CAPABILITY_VALIDATED, mWiFiAgent);
+        wifiCallback.expectCaps(mWiFiAgent, c -> c.hasCapability(NET_CAPABILITY_VALIDATED));
         validatedCallback.expectAvailableCallbacksValidated(mWiFiAgent);
-        validatedCallback.expectCapabilitiesWith(NET_CAPABILITY_PARTIAL_CONNECTIVITY, mWiFiAgent);
+        validatedCallback.expectCaps(mWiFiAgent,
+                c -> c.hasCapability(NET_CAPABILITY_PARTIAL_CONNECTIVITY));
 
         mCm.unregisterNetworkCallback(wifiCallback);
         mCm.unregisterNetworkCallback(validatedCallback);
@@ -4552,7 +4548,7 @@
         // This is necessary because of b/245893397, the same bug that happens where we use
         // expectAvailableDoubleValidatedCallbacks.
         // TODO : fix b/245893397 and remove this.
-        captivePortalCallback.expect(NETWORK_CAPS_UPDATED, mWiFiAgent);
+        captivePortalCallback.expectCaps(mWiFiAgent);
 
         startCaptivePortalApp(mWiFiAgent);
 
@@ -5239,7 +5235,8 @@
 
         // Suspend the network.
         mCellAgent.suspend();
-        cellNetworkCallback.expectCapabilitiesWithout(NET_CAPABILITY_NOT_SUSPENDED, mCellAgent);
+        cellNetworkCallback.expectCaps(mCellAgent,
+                c -> !c.hasCapability(NET_CAPABILITY_NOT_SUSPENDED));
         cellNetworkCallback.expect(SUSPENDED, mCellAgent);
         cellNetworkCallback.assertNoCallback();
         assertEquals(NetworkInfo.State.SUSPENDED, mCm.getActiveNetworkInfo().getState());
@@ -5254,7 +5251,8 @@
         mCm.unregisterNetworkCallback(dfltNetworkCallback);
 
         mCellAgent.resume();
-        cellNetworkCallback.expectCapabilitiesWith(NET_CAPABILITY_NOT_SUSPENDED, mCellAgent);
+        cellNetworkCallback.expectCaps(mCellAgent,
+                c -> c.hasCapability(NET_CAPABILITY_NOT_SUSPENDED));
         cellNetworkCallback.expect(RESUMED, mCellAgent);
         cellNetworkCallback.assertNoCallback();
         assertEquals(NetworkInfo.State.CONNECTED, mCm.getActiveNetworkInfo().getState());
@@ -5481,10 +5479,10 @@
         // When wifi connects, cell lingers.
         callback.expectAvailableCallbacksUnvalidated(mWiFiAgent);
         callback.expectLosing(mCellAgent);
-        callback.expectCapabilitiesWith(NET_CAPABILITY_VALIDATED, mWiFiAgent);
+        callback.expectCaps(mWiFiAgent, c -> c.hasCapability(NET_CAPABILITY_VALIDATED));
         fgCallback.expectAvailableCallbacksUnvalidated(mWiFiAgent);
         fgCallback.expectLosing(mCellAgent);
-        fgCallback.expectCapabilitiesWith(NET_CAPABILITY_VALIDATED, mWiFiAgent);
+        fgCallback.expectCaps(mWiFiAgent, c -> c.hasCapability(NET_CAPABILITY_VALIDATED));
         assertTrue(isForegroundNetwork(mCellAgent));
         assertTrue(isForegroundNetwork(mWiFiAgent));
 
@@ -5493,7 +5491,7 @@
         int timeoutMs = TEST_LINGER_DELAY_MS + TEST_LINGER_DELAY_MS / 4;
         fgCallback.expect(LOST, mCellAgent, timeoutMs);
         // Expect a network capabilities update sans FOREGROUND.
-        callback.expectCapabilitiesWithout(NET_CAPABILITY_FOREGROUND, mCellAgent);
+        callback.expectCaps(mCellAgent, c -> !c.hasCapability(NET_CAPABILITY_FOREGROUND));
         assertFalse(isForegroundNetwork(mCellAgent));
         assertTrue(isForegroundNetwork(mWiFiAgent));
 
@@ -5506,8 +5504,8 @@
         fgCallback.expectAvailableCallbacksValidated(mCellAgent);
         // Expect a network capabilities update with FOREGROUND, because the most recent
         // request causes its state to change.
-        cellCallback.expectCapabilitiesWith(NET_CAPABILITY_FOREGROUND, mCellAgent);
-        callback.expectCapabilitiesWith(NET_CAPABILITY_FOREGROUND, mCellAgent);
+        cellCallback.expectCaps(mCellAgent, c -> c.hasCapability(NET_CAPABILITY_FOREGROUND));
+        callback.expectCaps(mCellAgent, c -> c.hasCapability(NET_CAPABILITY_FOREGROUND));
         assertTrue(isForegroundNetwork(mCellAgent));
         assertTrue(isForegroundNetwork(mWiFiAgent));
 
@@ -5516,7 +5514,7 @@
         mCm.unregisterNetworkCallback(cellCallback);
         fgCallback.expect(LOST, mCellAgent);
         // Expect a network capabilities update sans FOREGROUND.
-        callback.expectCapabilitiesWithout(NET_CAPABILITY_FOREGROUND, mCellAgent);
+        callback.expectCaps(mCellAgent, c -> !c.hasCapability(NET_CAPABILITY_FOREGROUND));
         assertFalse(isForegroundNetwork(mCellAgent));
         assertTrue(isForegroundNetwork(mWiFiAgent));
 
@@ -5668,7 +5666,8 @@
             // Need a trigger point to let NetworkMonitor tell ConnectivityService that network is
             // validated – see testPartialConnectivity.
             mCm.reportNetworkConnectivity(mCellAgent.getNetwork(), true);
-            cellNetworkCallback.expectCapabilitiesWith(NET_CAPABILITY_VALIDATED, mCellAgent);
+            cellNetworkCallback.expectCaps(mCellAgent,
+                    c -> c.hasCapability(NET_CAPABILITY_VALIDATED));
             testFactory.expectRequestRemove();
             testFactory.assertRequestCountEquals(0);
             // Accordingly, the factory shouldn't be started.
@@ -5869,14 +5868,15 @@
         mWiFiAgent.setNetworkValid(true /* privateDnsProbeSent */);
         // Have CS reconsider the network (see testPartialConnectivity)
         mCm.reportNetworkConnectivity(mWiFiAgent.getNetwork(), true);
-        wifiNetworkCallback.expectCapabilitiesWith(NET_CAPABILITY_VALIDATED, mWiFiAgent);
+        wifiNetworkCallback.expectCaps(mWiFiAgent, c -> c.hasCapability(NET_CAPABILITY_VALIDATED));
         cellCallback.expectOnNetworkUnneeded(defaultCaps);
         wifiCallback.assertNoCallback();
 
         // Wifi is no longer validated. Cell is needed again.
         mWiFiAgent.setNetworkInvalid(true /* invalidBecauseOfPrivateDns */);
         mCm.reportNetworkConnectivity(mWiFiAgent.getNetwork(), false);
-        wifiNetworkCallback.expectCapabilitiesWithout(NET_CAPABILITY_VALIDATED, mWiFiAgent);
+        wifiNetworkCallback.expectCaps(mWiFiAgent,
+                c -> !c.hasCapability(NET_CAPABILITY_VALIDATED));
         cellCallback.expectOnNetworkNeeded(defaultCaps);
         wifiCallback.assertNoCallback();
 
@@ -5898,7 +5898,8 @@
         wifiCallback.assertNoCallback();
         mWiFiAgent.setNetworkValid(true /* privateDnsProbeSent */);
         mCm.reportNetworkConnectivity(mWiFiAgent.getNetwork(), true);
-        wifiNetworkCallback.expectCapabilitiesWith(NET_CAPABILITY_VALIDATED, mWiFiAgent);
+        wifiNetworkCallback.expectCaps(mWiFiAgent,
+                c -> c.hasCapability(NET_CAPABILITY_VALIDATED));
         cellCallback.expectOnNetworkUnneeded(defaultCaps);
         wifiCallback.assertNoCallback();
 
@@ -5906,7 +5907,8 @@
         // not needed.
         mWiFiAgent.setNetworkInvalid(true /* invalidBecauseOfPrivateDns */);
         mCm.reportNetworkConnectivity(mWiFiAgent.getNetwork(), false);
-        wifiNetworkCallback.expectCapabilitiesWithout(NET_CAPABILITY_VALIDATED, mWiFiAgent);
+        wifiNetworkCallback.expectCaps(mWiFiAgent,
+                c -> !c.hasCapability(NET_CAPABILITY_VALIDATED));
         cellCallback.assertNoCallback();
         wifiCallback.assertNoCallback();
     }
@@ -6001,7 +6003,7 @@
         // Fail validation on wifi.
         mWiFiAgent.setNetworkInvalid(false /* invalidBecauseOfPrivateDns */);
         mCm.reportNetworkConnectivity(wifiNetwork, false);
-        defaultCallback.expectCapabilitiesWithout(NET_CAPABILITY_VALIDATED, mWiFiAgent);
+        defaultCallback.expectCaps(mWiFiAgent, c -> !c.hasCapability(NET_CAPABILITY_VALIDATED));
         validatedWifiCallback.expect(LOST, mWiFiAgent);
         expectNotification(mWiFiAgent, NotificationType.LOST_INTERNET);
 
@@ -6052,7 +6054,7 @@
         // Fail validation on wifi and expect the dialog to appear.
         mWiFiAgent.setNetworkInvalid(false /* invalidBecauseOfPrivateDns */);
         mCm.reportNetworkConnectivity(wifiNetwork, false);
-        defaultCallback.expectCapabilitiesWithout(NET_CAPABILITY_VALIDATED, mWiFiAgent);
+        defaultCallback.expectCaps(mWiFiAgent, c -> !c.hasCapability(NET_CAPABILITY_VALIDATED));
         validatedWifiCallback.expect(LOST, mWiFiAgent);
         expectNotification(mWiFiAgent, NotificationType.LOST_INTERNET);
 
@@ -6998,7 +7000,7 @@
         assertNotPinnedToWifi();
 
         // Disconnect cell and wifi.
-        ExpectedBroadcast b = registerConnectivityBroadcast(3);  // cell down, wifi up, wifi down.
+        ExpectedBroadcast b = expectConnectivityAction(3);  // cell down, wifi up, wifi down.
         mCellAgent.disconnect();
         mWiFiAgent.disconnect();
         b.expectBroadcast();
@@ -7011,7 +7013,7 @@
         assertPinnedToWifiWithWifiDefault();
 
         // ... and is maintained even when that network is no longer the default.
-        b = registerConnectivityBroadcast(1);
+        b = expectConnectivityAction(1);
         mCellAgent = new TestNetworkAgentWrapper(TRANSPORT_WIFI);
         mCellAgent.connect(true);
         b.expectBroadcast();
@@ -7188,7 +7190,7 @@
 
     @Test
     public void testNetworkInfoOfTypeNone() throws Exception {
-        ExpectedBroadcast b = registerConnectivityBroadcast(1);
+        ExpectedBroadcast b = expectConnectivityAction(1);
 
         verifyNoNetwork();
         TestNetworkAgentWrapper wifiAware = new TestNetworkAgentWrapper(TRANSPORT_WIFI_AWARE);
@@ -7269,7 +7271,7 @@
         CallbackEntry.LinkPropertiesChanged cbi =
                 networkCallback.expect(LINK_PROPERTIES_CHANGED, networkAgent);
         networkCallback.expect(BLOCKED_STATUS, networkAgent);
-        networkCallback.expectCapabilitiesWith(NET_CAPABILITY_VALIDATED, networkAgent);
+        networkCallback.expectCaps(networkAgent, c -> c.hasCapability(NET_CAPABILITY_VALIDATED));
         networkCallback.assertNoCallback();
         checkDirectlyConnectedRoutes(cbi.getLp(), asList(myIpv4Address),
                 asList(myIpv4DefaultRoute));
@@ -7583,8 +7585,7 @@
         TestNetworkCallback callback = new TestNetworkCallback();
         mCm.registerDefaultNetworkCallback(callback);
         callback.expect(AVAILABLE, mCellAgent);
-        callback.expectCapabilitiesThat(
-                mCellAgent, nc -> Arrays.equals(adminUids, nc.getAdministratorUids()));
+        callback.expectCaps(mCellAgent, c -> Arrays.equals(adminUids, c.getAdministratorUids()));
         mCm.unregisterNetworkCallback(callback);
 
         // Verify case where caller does NOT have permission
@@ -7594,7 +7595,7 @@
         callback = new TestNetworkCallback();
         mCm.registerDefaultNetworkCallback(callback);
         callback.expect(AVAILABLE, mCellAgent);
-        callback.expectCapabilitiesThat(mCellAgent, nc -> nc.getAdministratorUids().length == 0);
+        callback.expectCaps(mCellAgent, c -> c.getAdministratorUids().length == 0);
     }
 
     @Test
@@ -8177,8 +8178,7 @@
             mMockVpn.setUnderlyingNetworks(new Network[]{wifiNetwork});
             // onCapabilitiesChanged() should be called because
             // NetworkCapabilities#mUnderlyingNetworks is updated.
-            CallbackEntry ce = callback.expect(NETWORK_CAPS_UPDATED, mMockVpn);
-            final NetworkCapabilities vpnNc1 = ((CallbackEntry.CapabilitiesChanged) ce).getCaps();
+            final NetworkCapabilities vpnNc1 = callback.expectCaps(mMockVpn);
             // Since the wifi network hasn't brought up,
             // ConnectivityService#applyUnderlyingCapabilities cannot find it. Update
             // NetworkCapabilities#mUnderlyingNetworks to an empty array, and it will be updated to
@@ -8213,8 +8213,7 @@
             // 2. When a network connects, updateNetworkInfo propagates underlying network
             //    capabilities before rematching networks.
             // Given that this scenario can't really happen, this is probably fine for now.
-            ce = callback.expect(NETWORK_CAPS_UPDATED, mMockVpn);
-            final NetworkCapabilities vpnNc2 = ((CallbackEntry.CapabilitiesChanged) ce).getCaps();
+            final NetworkCapabilities vpnNc2 = callback.expectCaps(mMockVpn);
             // The wifi network is brought up, NetworkCapabilities#mUnderlyingNetworks is updated to
             // it.
             underlyingNetwork.add(wifiNetwork);
@@ -8228,8 +8227,8 @@
             // Disconnect the network, and expect to see the VPN capabilities change accordingly.
             mWiFiAgent.disconnect();
             callback.expect(LOST, mWiFiAgent);
-            callback.expectCapabilitiesThat(mMockVpn, (nc) ->
-                    nc.getTransportTypes().length == 1 && nc.hasTransport(TRANSPORT_VPN));
+            callback.expectCaps(mMockVpn, c -> c.getTransportTypes().length == 1
+                            && c.hasTransport(TRANSPORT_VPN));
 
             mMockVpn.disconnect();
             mCm.unregisterNetworkCallback(callback);
@@ -8255,9 +8254,8 @@
         // Connect cellular data.
         mCellAgent = new TestNetworkAgentWrapper(TRANSPORT_CELLULAR);
         mCellAgent.connect(false /* validated */);
-        callback.expectCapabilitiesThat(mMockVpn,
-                nc -> nc.hasCapability(NET_CAPABILITY_NOT_SUSPENDED)
-                        && nc.hasTransport(TRANSPORT_CELLULAR));
+        callback.expectCaps(mMockVpn, c -> c.hasCapability(NET_CAPABILITY_NOT_SUSPENDED)
+                && c.hasTransport(TRANSPORT_CELLULAR));
         callback.assertNoCallback();
 
         assertTrue(mCm.getNetworkCapabilities(mMockVpn.getNetwork())
@@ -8270,9 +8268,8 @@
 
         // Suspend the cellular network and expect the VPN to be suspended.
         mCellAgent.suspend();
-        callback.expectCapabilitiesThat(mMockVpn,
-                nc -> !nc.hasCapability(NET_CAPABILITY_NOT_SUSPENDED)
-                        && nc.hasTransport(TRANSPORT_CELLULAR));
+        callback.expectCaps(mMockVpn, c -> !c.hasCapability(NET_CAPABILITY_NOT_SUSPENDED)
+                && c.hasTransport(TRANSPORT_CELLULAR));
         callback.expect(SUSPENDED, mMockVpn);
         callback.assertNoCallback();
 
@@ -8288,9 +8285,8 @@
         // Switch to another network. The VPN should no longer be suspended.
         mWiFiAgent = new TestNetworkAgentWrapper(TRANSPORT_WIFI);
         mWiFiAgent.connect(false /* validated */);
-        callback.expectCapabilitiesThat(mMockVpn,
-                nc -> nc.hasCapability(NET_CAPABILITY_NOT_SUSPENDED)
-                        && nc.hasTransport(TRANSPORT_WIFI));
+        callback.expectCaps(mMockVpn, c -> c.hasCapability(NET_CAPABILITY_NOT_SUSPENDED)
+                && c.hasTransport(TRANSPORT_WIFI));
         callback.expect(RESUMED, mMockVpn);
         callback.assertNoCallback();
 
@@ -8306,13 +8302,11 @@
         mCellAgent.resume();
         callback.assertNoCallback();
         mWiFiAgent.disconnect();
-        callback.expectCapabilitiesThat(mMockVpn,
-                nc -> nc.hasCapability(NET_CAPABILITY_NOT_SUSPENDED)
-                        && nc.hasTransport(TRANSPORT_CELLULAR));
+        callback.expectCaps(mMockVpn, c -> c.hasCapability(NET_CAPABILITY_NOT_SUSPENDED)
+                && c.hasTransport(TRANSPORT_CELLULAR));
         // Spurious double callback?
-        callback.expectCapabilitiesThat(mMockVpn,
-                nc -> nc.hasCapability(NET_CAPABILITY_NOT_SUSPENDED)
-                        && nc.hasTransport(TRANSPORT_CELLULAR));
+        callback.expectCaps(mMockVpn, c -> c.hasCapability(NET_CAPABILITY_NOT_SUSPENDED)
+                && c.hasTransport(TRANSPORT_CELLULAR));
         callback.assertNoCallback();
 
         assertTrue(mCm.getNetworkCapabilities(mMockVpn.getNetwork())
@@ -8325,9 +8319,8 @@
 
         // Suspend cellular and expect no connectivity.
         mCellAgent.suspend();
-        callback.expectCapabilitiesThat(mMockVpn,
-                nc -> !nc.hasCapability(NET_CAPABILITY_NOT_SUSPENDED)
-                        && nc.hasTransport(TRANSPORT_CELLULAR));
+        callback.expectCaps(mMockVpn, c -> !c.hasCapability(NET_CAPABILITY_NOT_SUSPENDED)
+                && c.hasTransport(TRANSPORT_CELLULAR));
         callback.expect(SUSPENDED, mMockVpn);
         callback.assertNoCallback();
 
@@ -8341,9 +8334,8 @@
 
         // Resume cellular and expect that connectivity comes back.
         mCellAgent.resume();
-        callback.expectCapabilitiesThat(mMockVpn,
-                nc -> nc.hasCapability(NET_CAPABILITY_NOT_SUSPENDED)
-                        && nc.hasTransport(TRANSPORT_CELLULAR));
+        callback.expectCaps(mMockVpn, c -> c.hasCapability(NET_CAPABILITY_NOT_SUSPENDED)
+                && c.hasTransport(TRANSPORT_CELLULAR));
         callback.expect(RESUMED, mMockVpn);
         callback.assertNoCallback();
 
@@ -8432,7 +8424,7 @@
         // can't currently update their UIDs without disconnecting, so this does not matter too
         // much, but that is the reason the test here has to check for an update to the
         // capabilities instead of the expected LOST then AVAILABLE.
-        defaultCallback.expect(NETWORK_CAPS_UPDATED, mMockVpn);
+        defaultCallback.expectCaps(mMockVpn);
         systemDefaultCallback.assertNoCallback();
 
         ranges.add(new UidRange(uid, uid));
@@ -8444,7 +8436,7 @@
         vpnNetworkCallback.expectAvailableCallbacksValidated(mMockVpn);
         // TODO : Here like above, AVAILABLE would be correct, but because this can't actually
         // happen outside of the test, ConnectivityService does not rematch callbacks.
-        defaultCallback.expect(NETWORK_CAPS_UPDATED, mMockVpn);
+        defaultCallback.expectCaps(mMockVpn);
         systemDefaultCallback.assertNoCallback();
 
         mWiFiAgent.disconnect();
@@ -8565,7 +8557,7 @@
         mMockVpn.getAgent().mNetworkMonitor.forceReevaluation(Process.myUid());
         // Expect to see the validated capability, but no other changes, because the VPN is already
         // the default network for the app.
-        callback.expectCapabilitiesWith(NET_CAPABILITY_VALIDATED, mMockVpn);
+        callback.expectCaps(mMockVpn, c -> c.hasCapability(NET_CAPABILITY_VALIDATED));
         callback.assertNoCallback();
 
         mMockVpn.disconnect();
@@ -8597,8 +8589,8 @@
 
         vpnNetworkCallback.expectAvailableCallbacks(mMockVpn.getNetwork(),
                 false /* suspended */, false /* validated */, false /* blocked */, TIMEOUT_MS);
-        vpnNetworkCallback.expectCapabilitiesThat(mMockVpn.getNetwork(), TIMEOUT_MS,
-                nc -> nc.hasCapability(NET_CAPABILITY_VALIDATED));
+        vpnNetworkCallback.expectCaps(mMockVpn.getNetwork(), TIMEOUT_MS,
+                c -> c.hasCapability(NET_CAPABILITY_VALIDATED));
 
         final NetworkCapabilities nc = mCm.getNetworkCapabilities(mMockVpn.getNetwork());
         assertTrue(nc.hasTransport(TRANSPORT_VPN));
@@ -8660,11 +8652,12 @@
 
         mMockVpn.setUnderlyingNetworks(new Network[] { mCellAgent.getNetwork() });
 
-        vpnNetworkCallback.expectCapabilitiesThat(mMockVpn,
-                (caps) -> caps.hasTransport(TRANSPORT_VPN)
-                && caps.hasTransport(TRANSPORT_CELLULAR) && !caps.hasTransport(TRANSPORT_WIFI)
-                && !caps.hasCapability(NET_CAPABILITY_NOT_METERED)
-                && caps.hasCapability(NET_CAPABILITY_NOT_SUSPENDED));
+        vpnNetworkCallback.expectCaps(mMockVpn,
+                c -> c.hasTransport(TRANSPORT_VPN)
+                        && c.hasTransport(TRANSPORT_CELLULAR)
+                        && !c.hasTransport(TRANSPORT_WIFI)
+                        && !c.hasCapability(NET_CAPABILITY_NOT_METERED)
+                        && c.hasCapability(NET_CAPABILITY_NOT_SUSPENDED));
         assertDefaultNetworkCapabilities(userId, mCellAgent);
 
         mWiFiAgent = new TestNetworkAgentWrapper(TRANSPORT_WIFI);
@@ -8675,62 +8668,68 @@
         mMockVpn.setUnderlyingNetworks(
                 new Network[] { mCellAgent.getNetwork(), mWiFiAgent.getNetwork() });
 
-        vpnNetworkCallback.expectCapabilitiesThat(mMockVpn,
-                (caps) -> caps.hasTransport(TRANSPORT_VPN)
-                && caps.hasTransport(TRANSPORT_CELLULAR) && caps.hasTransport(TRANSPORT_WIFI)
-                && !caps.hasCapability(NET_CAPABILITY_NOT_METERED)
-                && caps.hasCapability(NET_CAPABILITY_NOT_SUSPENDED));
+        vpnNetworkCallback.expectCaps(mMockVpn,
+                c -> c.hasTransport(TRANSPORT_VPN)
+                        && c.hasTransport(TRANSPORT_CELLULAR)
+                        && c.hasTransport(TRANSPORT_WIFI)
+                        && !c.hasCapability(NET_CAPABILITY_NOT_METERED)
+                        && c.hasCapability(NET_CAPABILITY_NOT_SUSPENDED));
         assertDefaultNetworkCapabilities(userId, mCellAgent, mWiFiAgent);
 
         // Don't disconnect, but note the VPN is not using wifi any more.
         mMockVpn.setUnderlyingNetworks(new Network[] { mCellAgent.getNetwork() });
 
-        vpnNetworkCallback.expectCapabilitiesThat(mMockVpn,
-                (caps) -> caps.hasTransport(TRANSPORT_VPN)
-                && caps.hasTransport(TRANSPORT_CELLULAR) && !caps.hasTransport(TRANSPORT_WIFI)
-                && !caps.hasCapability(NET_CAPABILITY_NOT_METERED)
-                && caps.hasCapability(NET_CAPABILITY_NOT_SUSPENDED));
+        vpnNetworkCallback.expectCaps(mMockVpn,
+                c -> c.hasTransport(TRANSPORT_VPN)
+                        && c.hasTransport(TRANSPORT_CELLULAR)
+                        && !c.hasTransport(TRANSPORT_WIFI)
+                        && !c.hasCapability(NET_CAPABILITY_NOT_METERED)
+                        && c.hasCapability(NET_CAPABILITY_NOT_SUSPENDED));
         // The return value of getDefaultNetworkCapabilitiesForUser always includes the default
         // network (wifi) as well as the underlying networks (cell).
         assertDefaultNetworkCapabilities(userId, mCellAgent, mWiFiAgent);
 
         // Remove NOT_SUSPENDED from the only network and observe VPN is now suspended.
         mCellAgent.removeCapability(NET_CAPABILITY_NOT_SUSPENDED);
-        vpnNetworkCallback.expectCapabilitiesThat(mMockVpn,
-                (caps) -> caps.hasTransport(TRANSPORT_VPN)
-                && caps.hasTransport(TRANSPORT_CELLULAR) && !caps.hasTransport(TRANSPORT_WIFI)
-                && !caps.hasCapability(NET_CAPABILITY_NOT_METERED)
-                && !caps.hasCapability(NET_CAPABILITY_NOT_SUSPENDED));
+        vpnNetworkCallback.expectCaps(mMockVpn,
+                c -> c.hasTransport(TRANSPORT_VPN)
+                        && c.hasTransport(TRANSPORT_CELLULAR)
+                        && !c.hasTransport(TRANSPORT_WIFI)
+                        && !c.hasCapability(NET_CAPABILITY_NOT_METERED)
+                        && !c.hasCapability(NET_CAPABILITY_NOT_SUSPENDED));
         vpnNetworkCallback.expect(SUSPENDED, mMockVpn);
 
         // Add NOT_SUSPENDED again and observe VPN is no longer suspended.
         mCellAgent.addCapability(NET_CAPABILITY_NOT_SUSPENDED);
-        vpnNetworkCallback.expectCapabilitiesThat(mMockVpn,
-                (caps) -> caps.hasTransport(TRANSPORT_VPN)
-                && caps.hasTransport(TRANSPORT_CELLULAR) && !caps.hasTransport(TRANSPORT_WIFI)
-                && !caps.hasCapability(NET_CAPABILITY_NOT_METERED)
-                && caps.hasCapability(NET_CAPABILITY_NOT_SUSPENDED));
+        vpnNetworkCallback.expectCaps(mMockVpn,
+                c -> c.hasTransport(TRANSPORT_VPN)
+                        && c.hasTransport(TRANSPORT_CELLULAR)
+                        && !c.hasTransport(TRANSPORT_WIFI)
+                        && !c.hasCapability(NET_CAPABILITY_NOT_METERED)
+                        && c.hasCapability(NET_CAPABILITY_NOT_SUSPENDED));
         vpnNetworkCallback.expect(RESUMED, mMockVpn);
 
         // Use Wifi but not cell. Note the VPN is now unmetered and not suspended.
         mMockVpn.setUnderlyingNetworks(new Network[] { mWiFiAgent.getNetwork() });
 
-        vpnNetworkCallback.expectCapabilitiesThat(mMockVpn,
-                (caps) -> caps.hasTransport(TRANSPORT_VPN)
-                && !caps.hasTransport(TRANSPORT_CELLULAR) && caps.hasTransport(TRANSPORT_WIFI)
-                && caps.hasCapability(NET_CAPABILITY_NOT_METERED)
-                && caps.hasCapability(NET_CAPABILITY_NOT_SUSPENDED));
+        vpnNetworkCallback.expectCaps(mMockVpn,
+                c -> c.hasTransport(TRANSPORT_VPN)
+                        && !c.hasTransport(TRANSPORT_CELLULAR)
+                        && c.hasTransport(TRANSPORT_WIFI)
+                        && c.hasCapability(NET_CAPABILITY_NOT_METERED)
+                        && c.hasCapability(NET_CAPABILITY_NOT_SUSPENDED));
         assertDefaultNetworkCapabilities(userId, mWiFiAgent);
 
         // Use both again.
         mMockVpn.setUnderlyingNetworks(
                 new Network[] { mCellAgent.getNetwork(), mWiFiAgent.getNetwork() });
 
-        vpnNetworkCallback.expectCapabilitiesThat(mMockVpn,
-                (caps) -> caps.hasTransport(TRANSPORT_VPN)
-                && caps.hasTransport(TRANSPORT_CELLULAR) && caps.hasTransport(TRANSPORT_WIFI)
-                && !caps.hasCapability(NET_CAPABILITY_NOT_METERED)
-                && caps.hasCapability(NET_CAPABILITY_NOT_SUSPENDED));
+        vpnNetworkCallback.expectCaps(mMockVpn,
+                c -> c.hasTransport(TRANSPORT_VPN)
+                        && c.hasTransport(TRANSPORT_CELLULAR)
+                        && c.hasTransport(TRANSPORT_WIFI)
+                        && !c.hasCapability(NET_CAPABILITY_NOT_METERED)
+                        && c.hasCapability(NET_CAPABILITY_NOT_SUSPENDED));
         assertDefaultNetworkCapabilities(userId, mCellAgent, mWiFiAgent);
 
         // Cell is suspended again. As WiFi is not, this should not cause a callback.
@@ -8739,11 +8738,11 @@
 
         // Stop using WiFi. The VPN is suspended again.
         mMockVpn.setUnderlyingNetworks(new Network[] { mCellAgent.getNetwork() });
-        vpnNetworkCallback.expectCapabilitiesThat(mMockVpn,
-                (caps) -> caps.hasTransport(TRANSPORT_VPN)
-                && caps.hasTransport(TRANSPORT_CELLULAR)
-                && !caps.hasCapability(NET_CAPABILITY_NOT_METERED)
-                && !caps.hasCapability(NET_CAPABILITY_NOT_SUSPENDED));
+        vpnNetworkCallback.expectCaps(mMockVpn,
+                c -> c.hasTransport(TRANSPORT_VPN)
+                        && c.hasTransport(TRANSPORT_CELLULAR)
+                        && !c.hasCapability(NET_CAPABILITY_NOT_METERED)
+                        && !c.hasCapability(NET_CAPABILITY_NOT_SUSPENDED));
         vpnNetworkCallback.expect(SUSPENDED, mMockVpn);
         assertDefaultNetworkCapabilities(userId, mCellAgent, mWiFiAgent);
 
@@ -8751,29 +8750,32 @@
         mMockVpn.setUnderlyingNetworks(
                 new Network[] { mCellAgent.getNetwork(), mWiFiAgent.getNetwork() });
 
-        vpnNetworkCallback.expectCapabilitiesThat(mMockVpn,
-                (caps) -> caps.hasTransport(TRANSPORT_VPN)
-                && caps.hasTransport(TRANSPORT_CELLULAR) && caps.hasTransport(TRANSPORT_WIFI)
-                && !caps.hasCapability(NET_CAPABILITY_NOT_METERED)
-                && caps.hasCapability(NET_CAPABILITY_NOT_SUSPENDED));
+        vpnNetworkCallback.expectCaps(mMockVpn,
+                c -> c.hasTransport(TRANSPORT_VPN)
+                        && c.hasTransport(TRANSPORT_CELLULAR)
+                        && c.hasTransport(TRANSPORT_WIFI)
+                        && !c.hasCapability(NET_CAPABILITY_NOT_METERED)
+                        && c.hasCapability(NET_CAPABILITY_NOT_SUSPENDED));
         vpnNetworkCallback.expect(RESUMED, mMockVpn);
         assertDefaultNetworkCapabilities(userId, mCellAgent, mWiFiAgent);
 
         // Disconnect cell. Receive update without even removing the dead network from the
         // underlying networks – it's dead anyway. Not metered any more.
         mCellAgent.disconnect();
-        vpnNetworkCallback.expectCapabilitiesThat(mMockVpn,
-                (caps) -> caps.hasTransport(TRANSPORT_VPN)
-                && !caps.hasTransport(TRANSPORT_CELLULAR) && caps.hasTransport(TRANSPORT_WIFI)
-                && caps.hasCapability(NET_CAPABILITY_NOT_METERED));
+        vpnNetworkCallback.expectCaps(mMockVpn,
+                c -> c.hasTransport(TRANSPORT_VPN)
+                        && !c.hasTransport(TRANSPORT_CELLULAR)
+                        && c.hasTransport(TRANSPORT_WIFI)
+                        && c.hasCapability(NET_CAPABILITY_NOT_METERED));
         assertDefaultNetworkCapabilities(userId, mWiFiAgent);
 
         // Disconnect wifi too. No underlying networks means this is now metered.
         mWiFiAgent.disconnect();
-        vpnNetworkCallback.expectCapabilitiesThat(mMockVpn,
-                (caps) -> caps.hasTransport(TRANSPORT_VPN)
-                && !caps.hasTransport(TRANSPORT_CELLULAR) && !caps.hasTransport(TRANSPORT_WIFI)
-                && !caps.hasCapability(NET_CAPABILITY_NOT_METERED));
+        vpnNetworkCallback.expectCaps(mMockVpn,
+                c -> c.hasTransport(TRANSPORT_VPN)
+                        && !c.hasTransport(TRANSPORT_CELLULAR)
+                        && !c.hasTransport(TRANSPORT_WIFI)
+                        && !c.hasCapability(NET_CAPABILITY_NOT_METERED));
         // When a network disconnects, the callbacks are fired before all state is updated, so for a
         // short time, synchronous calls will behave as if the network is still connected. Wait for
         // things to settle.
@@ -8814,20 +8816,22 @@
         mCellAgent = new TestNetworkAgentWrapper(TRANSPORT_CELLULAR);
         mCellAgent.connect(true);
 
-        vpnNetworkCallback.expectCapabilitiesThat(mMockVpn,
-                (caps) -> caps.hasTransport(TRANSPORT_VPN)
-                && caps.hasTransport(TRANSPORT_CELLULAR) && !caps.hasTransport(TRANSPORT_WIFI)
-                && !caps.hasCapability(NET_CAPABILITY_NOT_METERED));
+        vpnNetworkCallback.expectCaps(mMockVpn,
+                c -> c.hasTransport(TRANSPORT_VPN)
+                        && c.hasTransport(TRANSPORT_CELLULAR)
+                        && !c.hasTransport(TRANSPORT_WIFI)
+                        && !c.hasCapability(NET_CAPABILITY_NOT_METERED));
 
         // Connect to WiFi; WiFi is the new default.
         mWiFiAgent = new TestNetworkAgentWrapper(TRANSPORT_WIFI);
         mWiFiAgent.addCapability(NET_CAPABILITY_NOT_METERED);
         mWiFiAgent.connect(true);
 
-        vpnNetworkCallback.expectCapabilitiesThat(mMockVpn,
-                (caps) -> caps.hasTransport(TRANSPORT_VPN)
-                && !caps.hasTransport(TRANSPORT_CELLULAR) && caps.hasTransport(TRANSPORT_WIFI)
-                && caps.hasCapability(NET_CAPABILITY_NOT_METERED));
+        vpnNetworkCallback.expectCaps(mMockVpn,
+                c -> c.hasTransport(TRANSPORT_VPN)
+                        && !c.hasTransport(TRANSPORT_CELLULAR)
+                        && c.hasTransport(TRANSPORT_WIFI)
+                        && c.hasCapability(NET_CAPABILITY_NOT_METERED));
 
         // Disconnect Cell. The default network did not change, so there shouldn't be any changes in
         // the capabilities.
@@ -8836,10 +8840,11 @@
         // Disconnect wifi too. Now we have no default network.
         mWiFiAgent.disconnect();
 
-        vpnNetworkCallback.expectCapabilitiesThat(mMockVpn,
-                (caps) -> caps.hasTransport(TRANSPORT_VPN)
-                && !caps.hasTransport(TRANSPORT_CELLULAR) && !caps.hasTransport(TRANSPORT_WIFI)
-                && !caps.hasCapability(NET_CAPABILITY_NOT_METERED));
+        vpnNetworkCallback.expectCaps(mMockVpn,
+                c -> c.hasTransport(TRANSPORT_VPN)
+                        && !c.hasTransport(TRANSPORT_CELLULAR)
+                        && !c.hasTransport(TRANSPORT_WIFI)
+                        && !c.hasCapability(NET_CAPABILITY_NOT_METERED));
 
         mMockVpn.disconnect();
     }
@@ -8871,11 +8876,9 @@
         mWiFiAgent = new TestNetworkAgentWrapper(TRANSPORT_WIFI);
         mWiFiAgent.connect(true);
         callback.expectAvailableCallbacksUnvalidated(mWiFiAgent);
-        callback.expectCapabilitiesThat(mMockVpn, (caps)
-                -> caps.hasTransport(TRANSPORT_VPN)
-                && caps.hasTransport(TRANSPORT_WIFI));
-        callback.expectCapabilitiesThat(mWiFiAgent, (caps)
-                -> caps.hasCapability(NET_CAPABILITY_VALIDATED));
+        callback.expectCaps(mMockVpn, c -> c.hasTransport(TRANSPORT_VPN)
+                        && c.hasTransport(TRANSPORT_WIFI));
+        callback.expectCaps(mWiFiAgent, c -> c.hasCapability(NET_CAPABILITY_VALIDATED));
 
         doReturn(UserHandle.getUid(RESTRICTED_USER, VPN_UID)).when(mPackageManager)
                 .getPackageUidAsUser(ALWAYS_ON_PACKAGE, RESTRICTED_USER);
@@ -8886,35 +8889,35 @@
         // Expect that the VPN UID ranges contain both |uid| and the UID range for the newly-added
         // restricted user.
         final UidRange rRange = UidRange.createForUser(UserHandle.of(RESTRICTED_USER));
-        final Range<Integer> restrictUidRange = new Range<Integer>(rRange.start, rRange.stop);
-        final Range<Integer> singleUidRange = new Range<Integer>(uid, uid);
-        callback.expectCapabilitiesThat(mMockVpn, (caps)
-                -> caps.getUids().size() == 2
-                && caps.getUids().contains(singleUidRange)
-                && caps.getUids().contains(restrictUidRange)
-                && caps.hasTransport(TRANSPORT_VPN)
-                && caps.hasTransport(TRANSPORT_WIFI));
+        final Range<Integer> restrictUidRange = new Range<>(rRange.start, rRange.stop);
+        final Range<Integer> singleUidRange = new Range<>(uid, uid);
+        callback.expectCaps(mMockVpn, c ->
+                c.getUids().size() == 2
+                && c.getUids().contains(singleUidRange)
+                && c.getUids().contains(restrictUidRange)
+                && c.hasTransport(TRANSPORT_VPN)
+                && c.hasTransport(TRANSPORT_WIFI));
 
         // Change the VPN's capabilities somehow (specifically, disconnect wifi).
         mWiFiAgent.disconnect();
         callback.expect(LOST, mWiFiAgent);
-        callback.expectCapabilitiesThat(mMockVpn, (caps)
-                -> caps.getUids().size() == 2
-                && caps.getUids().contains(singleUidRange)
-                && caps.getUids().contains(restrictUidRange)
-                && caps.hasTransport(TRANSPORT_VPN)
-                && !caps.hasTransport(TRANSPORT_WIFI));
+        callback.expectCaps(mMockVpn, c ->
+                c.getUids().size() == 2
+                && c.getUids().contains(singleUidRange)
+                && c.getUids().contains(restrictUidRange)
+                && c.hasTransport(TRANSPORT_VPN)
+                && !c.hasTransport(TRANSPORT_WIFI));
 
         // User removed and expect to lose the UID range for the restricted user.
         mMockVpn.onUserRemoved(RESTRICTED_USER);
 
         // Expect that the VPN gains the UID range for the restricted user, and that the capability
         // change made just before that (i.e., loss of TRANSPORT_WIFI) is preserved.
-        callback.expectCapabilitiesThat(mMockVpn, (caps)
-                -> caps.getUids().size() == 1
-                && caps.getUids().contains(singleUidRange)
-                && caps.hasTransport(TRANSPORT_VPN)
-                && !caps.hasTransport(TRANSPORT_WIFI));
+        callback.expectCaps(mMockVpn, c ->
+                c.getUids().size() == 1
+                && c.getUids().contains(singleUidRange)
+                && c.hasTransport(TRANSPORT_VPN)
+                && !c.hasTransport(TRANSPORT_WIFI));
     }
 
     @Test
@@ -9226,9 +9229,10 @@
 
         // Restrict the network based on UID rule and NOT_METERED capability change.
         mCellAgent.addCapability(NET_CAPABILITY_NOT_METERED);
-        cellNetworkCallback.expectCapabilitiesWith(NET_CAPABILITY_NOT_METERED, mCellAgent);
+        cellNetworkCallback.expectCaps(mCellAgent,
+                c -> c.hasCapability(NET_CAPABILITY_NOT_METERED));
         cellNetworkCallback.expectBlockedStatusCallback(false, mCellAgent);
-        detailedCallback.expectCapabilitiesWith(NET_CAPABILITY_NOT_METERED, mCellAgent);
+        detailedCallback.expectCaps(mCellAgent, c -> c.hasCapability(NET_CAPABILITY_NOT_METERED));
         detailedCallback.expectBlockedStatusCallback(mCellAgent, BLOCKED_REASON_NONE);
         assertEquals(mCellAgent.getNetwork(), mCm.getActiveNetwork());
         assertActiveNetworkInfo(TYPE_MOBILE, DetailedState.CONNECTED);
@@ -9236,9 +9240,11 @@
         assertExtraInfoFromCmPresent(mCellAgent);
 
         mCellAgent.removeCapability(NET_CAPABILITY_NOT_METERED);
-        cellNetworkCallback.expectCapabilitiesWithout(NET_CAPABILITY_NOT_METERED, mCellAgent);
+        cellNetworkCallback.expectCaps(mCellAgent,
+                c -> !c.hasCapability(NET_CAPABILITY_NOT_METERED));
         cellNetworkCallback.expectBlockedStatusCallback(true, mCellAgent);
-        detailedCallback.expectCapabilitiesWithout(NET_CAPABILITY_NOT_METERED, mCellAgent);
+        detailedCallback.expectCaps(mCellAgent,
+                c -> !c.hasCapability(NET_CAPABILITY_NOT_METERED));
         detailedCallback.expectBlockedStatusCallback(mCellAgent, BLOCKED_METERED_REASON_DATA_SAVER);
         assertNull(mCm.getActiveNetwork());
         assertActiveNetworkInfo(TYPE_MOBILE, DetailedState.BLOCKED);
@@ -9301,7 +9307,7 @@
         mCellAgent = new TestNetworkAgentWrapper(TRANSPORT_CELLULAR);
         mCellAgent.connect(true);
         defaultCallback.expectAvailableCallbacksUnvalidatedAndBlocked(mCellAgent);
-        defaultCallback.expectCapabilitiesWith(NET_CAPABILITY_VALIDATED, mCellAgent);
+        defaultCallback.expectCaps(mCellAgent, c -> c.hasCapability(NET_CAPABILITY_VALIDATED));
 
         // Allow to use the network after switching to NOT_METERED network.
         mWiFiAgent = new TestNetworkAgentWrapper(TRANSPORT_WIFI);
@@ -9316,7 +9322,7 @@
 
         // Network becomes NOT_METERED.
         mCellAgent.addCapability(NET_CAPABILITY_NOT_METERED);
-        defaultCallback.expectCapabilitiesWith(NET_CAPABILITY_NOT_METERED, mCellAgent);
+        defaultCallback.expectCaps(mCellAgent, c -> c.hasCapability(NET_CAPABILITY_NOT_METERED));
         defaultCallback.expectBlockedStatusCallback(false, mCellAgent);
 
         // Verify there's no Networkcallbacks invoked after data saver on/off.
@@ -9886,7 +9892,7 @@
         callback.expect(LOST, mWiFiAgent);
         systemDefaultCallback.expect(LOST, mWiFiAgent);
         b1.expectBroadcast();
-        callback.expectCapabilitiesThat(mMockVpn, nc -> !nc.hasTransport(TRANSPORT_WIFI));
+        callback.expectCaps(mMockVpn, c -> !c.hasTransport(TRANSPORT_WIFI));
         mMockVpn.expectStopVpnRunnerPrivileged();
         callback.expect(LOST, mMockVpn);
         b2.expectBroadcast();
@@ -10055,7 +10061,7 @@
             // changes back to cellular.
             mWiFiAgent.removeCapability(testCap);
             callbackWithCap.expectAvailableCallbacksValidated(mCellAgent);
-            callbackWithoutCap.expectCapabilitiesWithout(testCap, mWiFiAgent);
+            callbackWithoutCap.expectCaps(mWiFiAgent, c -> !c.hasCapability(testCap));
             verify(mMockNetd).networkSetDefault(eq(mCellAgent.getNetwork().netId));
             reset(mMockNetd);
 
@@ -10705,7 +10711,7 @@
         mWiFiAgent.connect(true);
         networkCallback.expectAvailableCallbacksUnvalidated(mWiFiAgent);
         networkCallback.expectLosing(mCellAgent);
-        networkCallback.expectCapabilitiesWith(NET_CAPABILITY_VALIDATED, mWiFiAgent);
+        networkCallback.expectCaps(mWiFiAgent, c -> c.hasCapability(NET_CAPABILITY_VALIDATED));
         verify(mMockNetd, times(1)).idletimerAddInterface(eq(WIFI_IFNAME), anyInt(),
                 eq(Integer.toString(TRANSPORT_WIFI)));
         verify(mMockNetd, times(1)).idletimerRemoveInterface(eq(MOBILE_IFNAME), anyInt(),
@@ -10729,7 +10735,7 @@
         mWiFiAgent.connect(true);
         networkCallback.expectAvailableCallbacksUnvalidated(mWiFiAgent);
         networkCallback.expectLosing(mCellAgent);
-        networkCallback.expectCapabilitiesWith(NET_CAPABILITY_VALIDATED, mWiFiAgent);
+        networkCallback.expectCaps(mWiFiAgent, c -> c.hasCapability(NET_CAPABILITY_VALIDATED));
         verify(mMockNetd, times(1)).idletimerAddInterface(eq(WIFI_IFNAME), anyInt(),
                 eq(Integer.toString(TRANSPORT_WIFI)));
         verify(mMockNetd, times(1)).idletimerRemoveInterface(eq(MOBILE_IFNAME), anyInt(),
@@ -11445,9 +11451,9 @@
         // callback.
         mWiFiAgent.setNetworkCapabilities(ncTemplate.setTransportInfo(actualTransportInfo), true);
 
-        wifiNetworkCallback.expectCapabilitiesThat(mWiFiAgent,
-                nc -> Objects.equals(expectedOwnerUid, nc.getOwnerUid())
-                        && Objects.equals(expectedTransportInfo, nc.getTransportInfo()));
+        wifiNetworkCallback.expectCaps(mWiFiAgent,
+                c -> Objects.equals(expectedOwnerUid, c.getOwnerUid())
+                        && Objects.equals(expectedTransportInfo, c.getTransportInfo()));
     }
 
     @Test
@@ -12241,7 +12247,7 @@
         assertNull(mService.getProxyForNetwork(null));
         assertNull(mCm.getDefaultProxy());
 
-        final ExpectedBroadcast b1 = registerPacProxyBroadcast();
+        final ExpectedBroadcast b1 = expectProxyChangeAction();
         final LinkProperties lp = new LinkProperties();
         lp.setInterfaceName("tun0");
         lp.addRoute(new RouteInfo(new IpPrefix(Inet4Address.ANY, 0), null));
@@ -12254,7 +12260,7 @@
         b1.expectNoBroadcast(500);
 
         // Update to new range which is old range minus APP1, i.e. only APP2
-        final ExpectedBroadcast b2 = registerPacProxyBroadcast();
+        final ExpectedBroadcast b2 = expectProxyChangeAction();
         final Set<UidRange> newRanges = new HashSet<>(asList(
                 new UidRange(vpnRange.start, APP1_UID - 1),
                 new UidRange(APP1_UID + 1, vpnRange.stop)));
@@ -12268,20 +12274,20 @@
         b2.expectNoBroadcast(500);
 
         final ProxyInfo testProxyInfo = ProxyInfo.buildDirectProxy("test", 8888);
-        final ExpectedBroadcast b3 = registerPacProxyBroadcast();
+        final ExpectedBroadcast b3 = expectProxyChangeAction();
         lp.setHttpProxy(testProxyInfo);
         mMockVpn.sendLinkProperties(lp);
         waitForIdle();
         // Proxy is set, so send a proxy broadcast.
         b3.expectBroadcast();
 
-        final ExpectedBroadcast b4 = registerPacProxyBroadcast();
+        final ExpectedBroadcast b4 = expectProxyChangeAction();
         mMockVpn.setUids(vpnRanges);
         waitForIdle();
         // Uid has changed and proxy is already set, so send a proxy broadcast.
         b4.expectBroadcast();
 
-        final ExpectedBroadcast b5 = registerPacProxyBroadcast();
+        final ExpectedBroadcast b5 = expectProxyChangeAction();
         // Proxy is removed, send a proxy broadcast.
         lp.setHttpProxy(null);
         mMockVpn.sendLinkProperties(lp);
@@ -12314,7 +12320,7 @@
         lp.setHttpProxy(testProxyInfo);
         final UidRange vpnRange = PRIMARY_UIDRANGE;
         final Set<UidRange> vpnRanges = Collections.singleton(vpnRange);
-        final ExpectedBroadcast b1 = registerPacProxyBroadcast();
+        final ExpectedBroadcast b1 = expectProxyChangeAction();
         mMockVpn.setOwnerAndAdminUid(VPN_UID);
         mMockVpn.registerAgent(false, vpnRanges, lp);
         // In any case, the proxy broadcast won't be sent before VPN goes into CONNECTED state.
@@ -12322,7 +12328,7 @@
         // proxy broadcast will get null.
         b1.expectNoBroadcast(500);
 
-        final ExpectedBroadcast b2 = registerPacProxyBroadcast();
+        final ExpectedBroadcast b2 = expectProxyChangeAction();
         mMockVpn.connect(true /* validated */, true /* hasInternet */,
                 false /* privateDnsProbeSent */);
         waitForIdle();
@@ -12358,7 +12364,7 @@
         final LinkProperties cellularLp = new LinkProperties();
         cellularLp.setInterfaceName(MOBILE_IFNAME);
         final ProxyInfo testProxyInfo = ProxyInfo.buildDirectProxy("test", 8888);
-        final ExpectedBroadcast b = registerPacProxyBroadcast();
+        final ExpectedBroadcast b = expectProxyChangeAction();
         cellularLp.setHttpProxy(testProxyInfo);
         mCellAgent.sendLinkProperties(cellularLp);
         b.expectBroadcast();
@@ -12403,7 +12409,7 @@
         // sees the network come up and validate later
         allNetworksCb.expectAvailableCallbacksUnvalidated(mWiFiAgent);
         allNetworksCb.expectLosing(mCellAgent);
-        allNetworksCb.expectCapabilitiesWith(NET_CAPABILITY_VALIDATED, mWiFiAgent);
+        allNetworksCb.expectCaps(mWiFiAgent, c -> c.hasCapability(NET_CAPABILITY_VALIDATED));
         allNetworksCb.expect(LOST, mCellAgent, TEST_LINGER_DELAY_MS * 2);
 
         // The cell network has disconnected (see LOST above) because it was outscored and
@@ -14428,10 +14434,10 @@
         mDefaultNetworkCallback.expectAvailableThenValidatedCallbacks(mCellAgent);
 
         mCellAgent.addCapability(NET_CAPABILITY_TEMPORARILY_NOT_METERED);
-        mSystemDefaultNetworkCallback.expectCapabilitiesThat(mCellAgent, nc ->
-                nc.hasCapability(NET_CAPABILITY_TEMPORARILY_NOT_METERED));
-        mDefaultNetworkCallback.expectCapabilitiesThat(mCellAgent, nc ->
-                nc.hasCapability(NET_CAPABILITY_TEMPORARILY_NOT_METERED));
+        mSystemDefaultNetworkCallback.expectCaps(mCellAgent,
+                c -> c.hasCapability(NET_CAPABILITY_TEMPORARILY_NOT_METERED));
+        mDefaultNetworkCallback.expectCaps(mCellAgent,
+                c -> c.hasCapability(NET_CAPABILITY_TEMPORARILY_NOT_METERED));
 
         // default callbacks will be unregistered in tearDown
     }
@@ -14818,20 +14824,19 @@
         // not to the other apps.
         workAgent.setNetworkValid(true /* privateDnsProbeSent */);
         workAgent.mNetworkMonitor.forceReevaluation(Process.myUid());
-        profileDefaultNetworkCallback.expectCapabilitiesThat(workAgent,
-                nc -> nc.hasCapability(NET_CAPABILITY_VALIDATED)
-                        && nc.hasCapability(NET_CAPABILITY_ENTERPRISE)
-                        && nc.hasEnterpriseId(
-                                profileNetworkPreference.getPreferenceEnterpriseId())
-                        && nc.getEnterpriseIds().length == 1);
+        profileDefaultNetworkCallback.expectCaps(workAgent,
+                c -> c.hasCapability(NET_CAPABILITY_VALIDATED)
+                        && c.hasCapability(NET_CAPABILITY_ENTERPRISE)
+                        && c.hasEnterpriseId(profileNetworkPreference.getPreferenceEnterpriseId())
+                        && c.getEnterpriseIds().length == 1);
         if (disAllowProfileDefaultNetworkCallback != null) {
             assertNoCallbacks(disAllowProfileDefaultNetworkCallback);
         }
         assertNoCallbacks(mSystemDefaultNetworkCallback, mDefaultNetworkCallback);
 
         workAgent.addCapability(NET_CAPABILITY_TEMPORARILY_NOT_METERED);
-        profileDefaultNetworkCallback.expectCapabilitiesThat(workAgent, nc ->
-                nc.hasCapability(NET_CAPABILITY_TEMPORARILY_NOT_METERED));
+        profileDefaultNetworkCallback.expectCaps(workAgent,
+                c -> c.hasCapability(NET_CAPABILITY_TEMPORARILY_NOT_METERED));
         if (disAllowProfileDefaultNetworkCallback != null) {
             assertNoCallbacks(disAllowProfileDefaultNetworkCallback);
         }
@@ -14840,13 +14845,13 @@
         // Conversely, change a capability on the system-wide default network and make sure
         // that only the apps outside of the work profile receive the callbacks.
         mCellAgent.addCapability(NET_CAPABILITY_TEMPORARILY_NOT_METERED);
-        mSystemDefaultNetworkCallback.expectCapabilitiesThat(mCellAgent, nc ->
-                nc.hasCapability(NET_CAPABILITY_TEMPORARILY_NOT_METERED));
-        mDefaultNetworkCallback.expectCapabilitiesThat(mCellAgent, nc ->
-                nc.hasCapability(NET_CAPABILITY_TEMPORARILY_NOT_METERED));
+        mSystemDefaultNetworkCallback.expectCaps(mCellAgent,
+                c -> c.hasCapability(NET_CAPABILITY_TEMPORARILY_NOT_METERED));
+        mDefaultNetworkCallback.expectCaps(mCellAgent,
+                c -> c.hasCapability(NET_CAPABILITY_TEMPORARILY_NOT_METERED));
         if (disAllowProfileDefaultNetworkCallback != null) {
-            disAllowProfileDefaultNetworkCallback.expectCapabilitiesThat(mCellAgent, nc ->
-                    nc.hasCapability(NET_CAPABILITY_TEMPORARILY_NOT_METERED));
+            disAllowProfileDefaultNetworkCallback.expectCaps(mCellAgent,
+                    c -> c.hasCapability(NET_CAPABILITY_TEMPORARILY_NOT_METERED));
         }
         profileDefaultNetworkCallback.assertNoCallback();
 
@@ -14928,12 +14933,11 @@
 
         workAgent2.setNetworkValid(true /* privateDnsProbeSent */);
         workAgent2.mNetworkMonitor.forceReevaluation(Process.myUid());
-        profileDefaultNetworkCallback.expectCapabilitiesThat(workAgent2,
-                nc -> nc.hasCapability(NET_CAPABILITY_ENTERPRISE)
-                        && !nc.hasCapability(NET_CAPABILITY_NOT_RESTRICTED)
-                        && nc.hasEnterpriseId(
-                        profileNetworkPreference.getPreferenceEnterpriseId())
-                        && nc.getEnterpriseIds().length == 1);
+        profileDefaultNetworkCallback.expectCaps(workAgent2,
+                c -> c.hasCapability(NET_CAPABILITY_ENTERPRISE)
+                        && !c.hasCapability(NET_CAPABILITY_NOT_RESTRICTED)
+                        && c.hasEnterpriseId(profileNetworkPreference.getPreferenceEnterpriseId())
+                        && c.getEnterpriseIds().length == 1);
         if (disAllowProfileDefaultNetworkCallback != null) {
             assertNoCallbacks(disAllowProfileDefaultNetworkCallback);
         }
@@ -16119,7 +16123,7 @@
         nc.setAllowedUids(uids);
         agent.setNetworkCapabilities(nc, true /* sendToConnectivityService */);
         if (SdkLevel.isAtLeastT()) {
-            cb.expectCapabilitiesThat(agent, caps -> caps.getAllowedUids().equals(uids));
+            cb.expectCaps(agent, c -> c.getAllowedUids().equals(uids));
         } else {
             cb.assertNoCallback();
         }
@@ -16136,7 +16140,7 @@
         nc.setAllowedUids(uids);
         agent.setNetworkCapabilities(nc, true /* sendToConnectivityService */);
         if (SdkLevel.isAtLeastT()) {
-            cb.expectCapabilitiesThat(agent, caps -> caps.getAllowedUids().equals(uids));
+            cb.expectCaps(agent, c -> c.getAllowedUids().equals(uids));
             inOrder.verify(mMockNetd, times(1)).networkRemoveUidRangesParcel(uids200Parcel);
         } else {
             cb.assertNoCallback();
@@ -16147,7 +16151,7 @@
         nc.setAllowedUids(uids);
         agent.setNetworkCapabilities(nc, true /* sendToConnectivityService */);
         if (SdkLevel.isAtLeastT()) {
-            cb.expectCapabilitiesThat(agent, caps -> caps.getAllowedUids().equals(uids));
+            cb.expectCaps(agent, c -> c.getAllowedUids().equals(uids));
         } else {
             cb.assertNoCallback();
         }
@@ -16164,7 +16168,7 @@
         nc.setAllowedUids(uids);
         agent.setNetworkCapabilities(nc, true /* sendToConnectivityService */);
         if (SdkLevel.isAtLeastT()) {
-            cb.expectCapabilitiesThat(agent, caps -> caps.getAllowedUids().isEmpty());
+            cb.expectCaps(agent, c -> c.getAllowedUids().isEmpty());
             inOrder.verify(mMockNetd, times(1)).networkRemoveUidRangesParcel(uids600Parcel);
         } else {
             cb.assertNoCallback();
@@ -16217,8 +16221,7 @@
         ncb.setAllowedUids(serviceUidSet);
         mEthernetAgent.setNetworkCapabilities(ncb.build(), true /* sendToCS */);
         if (SdkLevel.isAtLeastT() && hasAutomotiveFeature) {
-            cb.expectCapabilitiesThat(mEthernetAgent,
-                    caps -> caps.getAllowedUids().equals(serviceUidSet));
+            cb.expectCaps(mEthernetAgent, c -> c.getAllowedUids().equals(serviceUidSet));
         } else {
             // S and no automotive feature must ignore access UIDs.
             cb.assertNoCallback(TEST_CALLBACK_TIMEOUT_MS);
@@ -16271,7 +16274,7 @@
         ncb.setAllowedUids(serviceUidSet);
         mCellAgent.setNetworkCapabilities(ncb.build(), true /* sendToCS */);
         if (SdkLevel.isAtLeastT()) {
-            cb.expectCapabilitiesThat(mCellAgent, cp -> cp.getAllowedUids().equals(serviceUidSet));
+            cb.expectCaps(mCellAgent, c -> c.getAllowedUids().equals(serviceUidSet));
         } else {
             // S must ignore access UIDs.
             cb.assertNoCallback(TEST_CALLBACK_TIMEOUT_MS);
@@ -16281,7 +16284,7 @@
         ncb.setAllowedUids(nonServiceUidSet);
         mCellAgent.setNetworkCapabilities(ncb.build(), true /* sendToCS */);
         if (SdkLevel.isAtLeastT()) {
-            cb.expectCapabilitiesThat(mCellAgent, cp -> cp.getAllowedUids().isEmpty());
+            cb.expectCaps(mCellAgent, c -> c.getAllowedUids().isEmpty());
         } else {
             // S must ignore access UIDs.
             cb.assertNoCallback(TEST_CALLBACK_TIMEOUT_MS);
@@ -17117,40 +17120,42 @@
             mWiFiAgent.setNetworkCapabilities(wifiNc2, true /* sendToConnectivityService */);
             // The only thing changed in this CAPS is the BSSID, which can't be tested for in this
             // test because it's redacted.
-            wifiNetworkCallback.expect(NETWORK_CAPS_UPDATED, mWiFiAgent);
-            mDefaultNetworkCallback.expect(NETWORK_CAPS_UPDATED, mWiFiAgent);
+            wifiNetworkCallback.expectCaps(mWiFiAgent);
+            mDefaultNetworkCallback.expectCaps(mWiFiAgent);
             mWiFiAgent.setNetworkPortal(TEST_REDIRECT_URL, false /* privateDnsProbeSent */);
             mCm.reportNetworkConnectivity(mWiFiAgent.getNetwork(), false);
             // Wi-Fi is now detected to have a portal : cell should become the default network.
             mDefaultNetworkCallback.expectAvailableCallbacksValidated(mCellAgent);
-            wifiNetworkCallback.expectCapabilitiesWithout(NET_CAPABILITY_VALIDATED, mWiFiAgent);
-            wifiNetworkCallback.expectCapabilitiesWith(NET_CAPABILITY_CAPTIVE_PORTAL, mWiFiAgent);
+            wifiNetworkCallback.expectCaps(mWiFiAgent,
+                    c -> !c.hasCapability(NET_CAPABILITY_VALIDATED));
+            wifiNetworkCallback.expectCaps(mWiFiAgent,
+                    c -> c.hasCapability(NET_CAPABILITY_CAPTIVE_PORTAL));
 
             // Wi-Fi becomes valid again. The default network goes back to Wi-Fi.
             mWiFiAgent.setNetworkValid(false /* privateDnsProbeSent */);
             mCm.reportNetworkConnectivity(mWiFiAgent.getNetwork(), true);
             mDefaultNetworkCallback.expectAvailableCallbacksValidated(mWiFiAgent);
-            wifiNetworkCallback.expectCapabilitiesWithout(NET_CAPABILITY_CAPTIVE_PORTAL,
-                    mWiFiAgent);
+            wifiNetworkCallback.expectCaps(mWiFiAgent,
+                    c -> !c.hasCapability(NET_CAPABILITY_CAPTIVE_PORTAL));
 
             // Wi-Fi roaming from wifiNc2 to wifiNc1, and the network now has partial connectivity.
             mWiFiAgent.setNetworkCapabilities(wifiNc1, true);
-            wifiNetworkCallback.expect(NETWORK_CAPS_UPDATED, mWiFiAgent);
-            mDefaultNetworkCallback.expect(NETWORK_CAPS_UPDATED, mWiFiAgent);
+            wifiNetworkCallback.expectCaps(mWiFiAgent);
+            mDefaultNetworkCallback.expectCaps(mWiFiAgent);
             mWiFiAgent.setNetworkPartial();
             mCm.reportNetworkConnectivity(mWiFiAgent.getNetwork(), false);
             // Wi-Fi now only offers partial connectivity, so in the absence of accepting partial
             // connectivity explicitly for this network, it loses default status to cell.
             mDefaultNetworkCallback.expectAvailableCallbacksValidated(mCellAgent);
-            wifiNetworkCallback.expectCapabilitiesWith(NET_CAPABILITY_PARTIAL_CONNECTIVITY,
-                    mWiFiAgent);
+            wifiNetworkCallback.expectCaps(mWiFiAgent,
+                    c -> c.hasCapability(NET_CAPABILITY_PARTIAL_CONNECTIVITY));
 
             // Wi-Fi becomes valid again. The default network goes back to Wi-Fi.
             mWiFiAgent.setNetworkValid(false /* privateDnsProbeSent */);
             mCm.reportNetworkConnectivity(mWiFiAgent.getNetwork(), true);
             mDefaultNetworkCallback.expectAvailableCallbacksValidated(mWiFiAgent);
-            wifiNetworkCallback.expectCapabilitiesWithout(NET_CAPABILITY_PARTIAL_CONNECTIVITY,
-                    mWiFiAgent);
+            wifiNetworkCallback.expectCaps(mWiFiAgent,
+                    c -> !c.hasCapability(NET_CAPABILITY_PARTIAL_CONNECTIVITY));
         }
         mCm.unregisterNetworkCallback(wifiNetworkCallback);
 
@@ -17158,7 +17163,7 @@
         // failures after roam are not ignored, this will cause cell to become the default network.
         // If they are ignored, this will not cause a switch until later.
         mWiFiAgent.setNetworkCapabilities(wifiNc2, true);
-        mDefaultNetworkCallback.expect(NETWORK_CAPS_UPDATED, mWiFiAgent);
+        mDefaultNetworkCallback.expectCaps(mWiFiAgent);
         mWiFiAgent.setNetworkInvalid(false /* invalidBecauseOfPrivateDns */);
         mCm.reportNetworkConnectivity(mWiFiAgent.getNetwork(), false);
 
diff --git a/tests/unit/java/com/android/server/IpSecServiceTest.java b/tests/unit/java/com/android/server/IpSecServiceTest.java
index 6955620..4b6857c 100644
--- a/tests/unit/java/com/android/server/IpSecServiceTest.java
+++ b/tests/unit/java/com/android/server/IpSecServiceTest.java
@@ -82,7 +82,7 @@
     private static final int MAX_NUM_ENCAP_SOCKETS = 100;
     private static final int MAX_NUM_SPIS = 100;
     private static final int TEST_UDP_ENCAP_INVALID_PORT = 100;
-    private static final int TEST_UDP_ENCAP_PORT_OUT_RANGE = 100000;
+    private static final int TEST_UDP_ENCAP_PORT_OUT_RANGE = 200000;
 
     private static final InetAddress INADDR_ANY;
 
diff --git a/tests/unit/java/com/android/server/NsdServiceTest.java b/tests/unit/java/com/android/server/NsdServiceTest.java
index 98a8ed2..5a3bc64 100644
--- a/tests/unit/java/com/android/server/NsdServiceTest.java
+++ b/tests/unit/java/com/android/server/NsdServiceTest.java
@@ -45,6 +45,7 @@
 import static org.mockito.Mockito.timeout;
 import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.verifyNoMoreInteractions;
 import static org.mockito.Mockito.when;
 
 import android.compat.testing.PlatformCompatChangeRule;
@@ -170,6 +171,9 @@
         doReturn(true).when(mMockMDnsM).resolve(
                 anyInt(), anyString(), anyString(), anyString(), anyInt());
         doReturn(false).when(mDeps).isMdnsDiscoveryManagerEnabled(any(Context.class));
+        doReturn(mDiscoveryManager).when(mDeps).makeMdnsDiscoveryManager(any(), any());
+        doReturn(mSocketProvider).when(mDeps).makeMdnsSocketProvider(any(), any());
+        doReturn(mAdvertiser).when(mDeps).makeMdnsAdvertiser(any(), any(), any());
 
         mService = makeService();
     }
@@ -625,7 +629,7 @@
         waitForIdle();
 
         verify(mMockMDnsM).stopOperation(resolveId);
-        verify(resolveListener, timeout(TIMEOUT_MS)).onResolveStopped(argThat(ns ->
+        verify(resolveListener, timeout(TIMEOUT_MS)).onResolutionStopped(argThat(ns ->
                 request.getServiceName().equals(ns.getServiceName())
                         && request.getServiceType().equals(ns.getServiceType())));
     }
@@ -692,7 +696,7 @@
         waitForIdle();
 
         verify(mMockMDnsM).stopOperation(getAddrId);
-        verify(resolveListener, timeout(TIMEOUT_MS)).onResolveStopped(argThat(ns ->
+        verify(resolveListener, timeout(TIMEOUT_MS)).onResolutionStopped(argThat(ns ->
                 request.getServiceName().equals(ns.getServiceName())
                         && request.getServiceType().equals(ns.getServiceType())));
     }
@@ -824,40 +828,50 @@
                 client.unregisterServiceInfoCallback(serviceInfoCallback));
     }
 
-    private void makeServiceWithMdnsDiscoveryManagerEnabled() {
+    private void setMdnsDiscoveryManagerEnabled() {
         doReturn(true).when(mDeps).isMdnsDiscoveryManagerEnabled(any(Context.class));
-        doReturn(mDiscoveryManager).when(mDeps).makeMdnsDiscoveryManager(any(), any());
-        doReturn(mSocketProvider).when(mDeps).makeMdnsSocketProvider(any(), any());
-
-        mService = makeService();
-        verify(mDeps).makeMdnsDiscoveryManager(any(), any());
-        verify(mDeps).makeMdnsSocketProvider(any(), any());
     }
 
-    private void makeServiceWithMdnsAdvertiserEnabled() {
+    private void setMdnsAdvertiserEnabled() {
         doReturn(true).when(mDeps).isMdnsAdvertiserEnabled(any(Context.class));
-        doReturn(mAdvertiser).when(mDeps).makeMdnsAdvertiser(any(), any(), any());
-        doReturn(mSocketProvider).when(mDeps).makeMdnsSocketProvider(any(), any());
-
-        mService = makeService();
-        verify(mDeps).makeMdnsAdvertiser(any(), any(), any());
-        verify(mDeps).makeMdnsSocketProvider(any(), any());
     }
 
     @Test
     public void testMdnsDiscoveryManagerFeature() {
         // Create NsdService w/o feature enabled.
-        connectClient(mService);
-        verify(mDeps, never()).makeMdnsDiscoveryManager(any(), any());
-        verify(mDeps, never()).makeMdnsSocketProvider(any(), any());
+        final NsdManager client = connectClient(mService);
+        final DiscoveryListener discListenerWithoutFeature = mock(DiscoveryListener.class);
+        client.discoverServices(SERVICE_TYPE, PROTOCOL, discListenerWithoutFeature);
+        waitForIdle();
 
-        // Create NsdService again w/ feature enabled.
-        makeServiceWithMdnsDiscoveryManagerEnabled();
+        final ArgumentCaptor<Integer> legacyIdCaptor = ArgumentCaptor.forClass(Integer.class);
+        verify(mMockMDnsM).discover(legacyIdCaptor.capture(), any(), anyInt());
+        verifyNoMoreInteractions(mDiscoveryManager);
+
+        setMdnsDiscoveryManagerEnabled();
+        final DiscoveryListener discListenerWithFeature = mock(DiscoveryListener.class);
+        client.discoverServices(SERVICE_TYPE, PROTOCOL, discListenerWithFeature);
+        waitForIdle();
+
+        final String serviceTypeWithLocalDomain = SERVICE_TYPE + ".local";
+        final ArgumentCaptor<MdnsServiceBrowserListener> listenerCaptor =
+                ArgumentCaptor.forClass(MdnsServiceBrowserListener.class);
+        verify(mDiscoveryManager).registerListener(eq(serviceTypeWithLocalDomain),
+                listenerCaptor.capture(), any());
+
+        client.stopServiceDiscovery(discListenerWithoutFeature);
+        waitForIdle();
+        verify(mMockMDnsM).stopOperation(legacyIdCaptor.getValue());
+
+        client.stopServiceDiscovery(discListenerWithFeature);
+        waitForIdle();
+        verify(mDiscoveryManager).unregisterListener(serviceTypeWithLocalDomain,
+                listenerCaptor.getValue());
     }
 
     @Test
     public void testDiscoveryWithMdnsDiscoveryManager() {
-        makeServiceWithMdnsDiscoveryManagerEnabled();
+        setMdnsDiscoveryManagerEnabled();
 
         final NsdManager client = connectClient(mService);
         final DiscoveryListener discListener = mock(DiscoveryListener.class);
@@ -922,7 +936,7 @@
 
     @Test
     public void testDiscoveryWithMdnsDiscoveryManager_FailedWithInvalidServiceType() {
-        makeServiceWithMdnsDiscoveryManagerEnabled();
+        setMdnsDiscoveryManagerEnabled();
 
         final NsdManager client = connectClient(mService);
         final DiscoveryListener discListener = mock(DiscoveryListener.class);
@@ -951,7 +965,7 @@
 
     @Test
     public void testResolutionWithMdnsDiscoveryManager() throws UnknownHostException {
-        makeServiceWithMdnsDiscoveryManagerEnabled();
+        setMdnsDiscoveryManagerEnabled();
 
         final NsdManager client = connectClient(mService);
         final ResolveListener resolveListener = mock(ResolveListener.class);
@@ -1005,8 +1019,43 @@
     }
 
     @Test
+    public void testMdnsAdvertiserFeatureFlagging() {
+        // Create NsdService w/o feature enabled.
+        final NsdManager client = connectClient(mService);
+        final NsdServiceInfo regInfo = new NsdServiceInfo(SERVICE_NAME, SERVICE_TYPE);
+        regInfo.setHost(parseNumericAddress("192.0.2.123"));
+        regInfo.setPort(12345);
+        final RegistrationListener regListenerWithoutFeature = mock(RegistrationListener.class);
+        client.registerService(regInfo, PROTOCOL, regListenerWithoutFeature);
+        waitForIdle();
+
+        final ArgumentCaptor<Integer> legacyIdCaptor = ArgumentCaptor.forClass(Integer.class);
+        verify(mMockMDnsM).registerService(legacyIdCaptor.capture(), any(), any(), anyInt(),
+                any(), anyInt());
+        verifyNoMoreInteractions(mAdvertiser);
+
+        setMdnsAdvertiserEnabled();
+        final RegistrationListener regListenerWithFeature = mock(RegistrationListener.class);
+        client.registerService(regInfo, PROTOCOL, regListenerWithFeature);
+        waitForIdle();
+
+        final ArgumentCaptor<Integer> serviceIdCaptor = ArgumentCaptor.forClass(Integer.class);
+        verify(mAdvertiser).addService(serviceIdCaptor.capture(),
+                argThat(info -> matches(info, regInfo)));
+
+        client.unregisterService(regListenerWithoutFeature);
+        waitForIdle();
+        verify(mMockMDnsM).stopOperation(legacyIdCaptor.getValue());
+        verify(mAdvertiser, never()).removeService(anyInt());
+
+        client.unregisterService(regListenerWithFeature);
+        waitForIdle();
+        verify(mAdvertiser).removeService(serviceIdCaptor.getValue());
+    }
+
+    @Test
     public void testAdvertiseWithMdnsAdvertiser() {
-        makeServiceWithMdnsAdvertiserEnabled();
+        setMdnsAdvertiserEnabled();
 
         final NsdManager client = connectClient(mService);
         final RegistrationListener regListener = mock(RegistrationListener.class);
@@ -1045,7 +1094,7 @@
 
     @Test
     public void testAdvertiseWithMdnsAdvertiser_FailedWithInvalidServiceType() {
-        makeServiceWithMdnsAdvertiserEnabled();
+        setMdnsAdvertiserEnabled();
 
         final NsdManager client = connectClient(mService);
         final RegistrationListener regListener = mock(RegistrationListener.class);
@@ -1070,7 +1119,7 @@
 
     @Test
     public void testAdvertiseWithMdnsAdvertiser_LongServiceName() {
-        makeServiceWithMdnsAdvertiserEnabled();
+        setMdnsAdvertiserEnabled();
 
         final NsdManager client = connectClient(mService);
         final RegistrationListener regListener = mock(RegistrationListener.class);
diff --git a/tests/unit/java/com/android/server/connectivity/NetworkNotificationManagerTest.java b/tests/unit/java/com/android/server/connectivity/NetworkNotificationManagerTest.java
index 9a5298d..e038c44 100644
--- a/tests/unit/java/com/android/server/connectivity/NetworkNotificationManagerTest.java
+++ b/tests/unit/java/com/android/server/connectivity/NetworkNotificationManagerTest.java
@@ -56,8 +56,10 @@
 import android.net.NetworkInfo;
 import android.os.Build;
 import android.os.Bundle;
+import android.os.PowerManager;
 import android.os.UserHandle;
 import android.telephony.TelephonyManager;
+import android.testing.PollingCheck;
 import android.util.DisplayMetrics;
 import android.widget.TextView;
 
@@ -391,7 +393,15 @@
 
         final Instrumentation instr = InstrumentationRegistry.getInstrumentation();
         final UiDevice uiDevice =  UiDevice.getInstance(instr);
-        UiDevice.getInstance(instr).pressHome();
+        final Context ctx = instr.getContext();
+        final PowerManager pm = ctx.getSystemService(PowerManager.class);
+
+        // Wake up the device (it has no effect if the device is already awake).
+        uiDevice.executeShellCommand("input keyevent KEYCODE_WAKEUP");
+        uiDevice.executeShellCommand("wm dismiss-keyguard");
+        PollingCheck.check("Wait for the screen to be turned on failed, timeout=" + TEST_TIMEOUT_MS,
+                TEST_TIMEOUT_MS, () -> pm.isInteractive());
+        uiDevice.pressHome();
 
         // UiDevice.getLauncherPackageName() requires the test manifest to have a <queries> tag for
         // the launcher intent.
@@ -404,7 +414,6 @@
         // Non-"no internet" notifications are not affected
         verify(mNotificationManager).notify(eq(TEST_NOTIF_TAG), eq(NETWORK_SWITCH.eventId), any());
 
-        final Context ctx = instr.getContext();
         final String testAction = "com.android.connectivity.coverage.TEST_DIALOG";
         final Intent intent = new Intent(testAction)
                 .addFlags(Intent.FLAG_ACTIVITY_NEW_TASK)
diff --git a/tools/gn2bp/Android.bp.swp b/tools/gn2bp/Android.bp.swp
index 9f34b06..9cc72a1 100644
--- a/tools/gn2bp/Android.bp.swp
+++ b/tools/gn2bp/Android.bp.swp
@@ -10821,15 +10821,11 @@
         "SPDX-license-identifier-Apache-2.0",
         "SPDX-license-identifier-BSD",
         "SPDX-license-identifier-BSL-1.0",
-        "SPDX-license-identifier-GPL",
-        "SPDX-license-identifier-GPL-2.0",
-        "SPDX-license-identifier-GPL-3.0",
         "SPDX-license-identifier-ICU",
         "SPDX-license-identifier-ISC",
-        "SPDX-license-identifier-LGPL",
-        "SPDX-license-identifier-LGPL-2.1",
         "SPDX-license-identifier-MIT",
         "SPDX-license-identifier-MPL",
+        "SPDX-license-identifier-MPL-1.1",
         "SPDX-license-identifier-MPL-2.0",
         "SPDX-license-identifier-NCSA",
         "SPDX-license-identifier-OpenSSL",
diff --git a/tools/gn2bp/gen_android_bp b/tools/gn2bp/gen_android_bp
index 9d2d858..d8bc116 100755
--- a/tools/gn2bp/gen_android_bp
+++ b/tools/gn2bp/gen_android_bp
@@ -1667,24 +1667,20 @@
 def create_license_module(blueprint):
   module = Module("license", "external_cronet_license", "LICENSE")
   module.license_kinds.update({
-      'SPDX-license-identifier-LGPL-2.1',
-      'SPDX-license-identifier-GPL-2.0',
       'SPDX-license-identifier-MPL',
+      'SPDX-license-identifier-MPL-1.1',
       'SPDX-license-identifier-ISC',
-      'SPDX-license-identifier-GPL',
       'SPDX-license-identifier-AFL-2.0',
       'SPDX-license-identifier-MPL-2.0',
       'SPDX-license-identifier-BSD',
       'SPDX-license-identifier-Apache-2.0',
       'SPDX-license-identifier-BSL-1.0',
-      'SPDX-license-identifier-LGPL',
-      'SPDX-license-identifier-GPL-3.0',
       'SPDX-license-identifier-Unicode-DFS',
       'SPDX-license-identifier-NCSA',
       'SPDX-license-identifier-OpenSSL',
       'SPDX-license-identifier-MIT',
       "SPDX-license-identifier-ICU",
-      'legacy_unencumbered', # public domain
+      'legacy_unencumbered',
   })
   module.license_text.update({
       "LICENSE",