Merge "NetBpfLoader: create /sys/fs/bpf/loader dir" into main
diff --git a/framework/Android.bp b/framework/Android.bp
index 8fa336a..f76bbe1 100644
--- a/framework/Android.bp
+++ b/framework/Android.bp
@@ -96,7 +96,6 @@
     ],
     impl_only_static_libs: [
         "net-utils-device-common-bpf",
-        "net-utils-device-common-struct",
     ],
     libs: [
         "androidx.annotation_annotation",
@@ -125,7 +124,6 @@
         // Even if the library is included in "impl_only_static_libs" of defaults. This is still
         // needed because java_library which doesn't understand "impl_only_static_libs".
         "net-utils-device-common-bpf",
-        "net-utils-device-common-struct",
     ],
     libs: [
         // This cannot be in the defaults clause above because if it were, it would be used
diff --git a/service-t/src/com/android/server/NsdService.java b/service-t/src/com/android/server/NsdService.java
index 9ba49d2..c045eaf 100644
--- a/service-t/src/com/android/server/NsdService.java
+++ b/service-t/src/com/android/server/NsdService.java
@@ -92,6 +92,7 @@
 import com.android.modules.utils.build.SdkLevel;
 import com.android.net.module.util.CollectionUtils;
 import com.android.net.module.util.DeviceConfigUtils;
+import com.android.net.module.util.HandlerUtils;
 import com.android.net.module.util.InetAddressUtils;
 import com.android.net.module.util.PermissionUtils;
 import com.android.net.module.util.SharedLog;
@@ -2510,6 +2511,14 @@
         pw.increaseIndent();
         mServiceLogs.reverseDump(pw);
         pw.decreaseIndent();
+
+        // Dump DiscoveryManager
+        pw.println();
+        pw.println("DiscoveryManager:");
+        pw.increaseIndent();
+        HandlerUtils.runWithScissorsForDump(
+                mNsdStateMachine.getHandler(), () -> mMdnsDiscoveryManager.dump(pw), 10_000);
+        pw.decreaseIndent();
     }
 
     private abstract static class ClientRequest {
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsDiscoveryManager.java b/service-t/src/com/android/server/connectivity/mdns/MdnsDiscoveryManager.java
index 1d6039c..21b7069 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsDiscoveryManager.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsDiscoveryManager.java
@@ -34,6 +34,7 @@
 import com.android.server.connectivity.mdns.util.MdnsUtils;
 
 import java.io.IOException;
+import java.io.PrintWriter;
 import java.util.ArrayList;
 import java.util.List;
 import java.util.Objects;
@@ -363,4 +364,18 @@
                 executorProvider.newServiceTypeClientSchedulerExecutor(), socketKey,
                 sharedLog.forSubComponent(tag), looper, serviceCache);
     }
+
+    /**
+     * Dump DiscoveryManager state.
+     */
+    public void dump(PrintWriter pw) {
+        discoveryExecutor.checkAndRunOnHandlerThread(() -> {
+            pw.println();
+            // Dump ServiceTypeClients
+            for (MdnsServiceTypeClient serviceTypeClient
+                    : perSocketServiceTypeClients.getAllMdnsServiceTypeClient()) {
+                serviceTypeClient.dump(pw);
+            }
+        });
+    }
 }
\ No newline at end of file
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsServiceTypeClient.java b/service-t/src/com/android/server/connectivity/mdns/MdnsServiceTypeClient.java
index e222fcf..33b5ea4 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsServiceTypeClient.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsServiceTypeClient.java
@@ -35,6 +35,7 @@
 import com.android.net.module.util.SharedLog;
 import com.android.server.connectivity.mdns.util.MdnsUtils;
 
+import java.io.PrintWriter;
 import java.net.Inet4Address;
 import java.net.Inet6Address;
 import java.util.ArrayList;
@@ -756,4 +757,13 @@
                 args.sessionId, timeToNextTasksWithBackoffInMs));
         return timeToNextTasksWithBackoffInMs;
     }
+
+    /**
+     * Dump ServiceTypeClient state.
+     */
+    public void dump(PrintWriter pw) {
+        ensureRunningOnHandlerThread(handler);
+        pw.println("ServiceTypeClient: Type{" + serviceType + "} " + socketKey + " with "
+                + listeners.size() + " listeners.");
+    }
 }
\ No newline at end of file
diff --git a/service-t/src/com/android/server/net/NetworkStatsService.java b/service-t/src/com/android/server/net/NetworkStatsService.java
index 80c4033..9684d18 100644
--- a/service-t/src/com/android/server/net/NetworkStatsService.java
+++ b/service-t/src/com/android/server/net/NetworkStatsService.java
@@ -2231,7 +2231,7 @@
                             .setDefaultNetwork(true)
                             .setOemManaged(ident.getOemManaged())
                             .setSubId(ident.getSubId()).build();
-                    final String ifaceVt = IFACE_VT + getSubIdForMobile(snapshot);
+                    final String ifaceVt = IFACE_VT + getSubIdForCellularOrSatellite(snapshot);
                     findOrCreateNetworkIdentitySet(mActiveIfaces, ifaceVt).add(vtIdent);
                     findOrCreateNetworkIdentitySet(mActiveUidIfaces, ifaceVt).add(vtIdent);
                 }
@@ -2300,9 +2300,15 @@
         mMobileIfaces = mobileIfaces.toArray(new String[0]);
     }
 
-    private static int getSubIdForMobile(@NonNull NetworkStateSnapshot state) {
-        if (!state.getNetworkCapabilities().hasTransport(NetworkCapabilities.TRANSPORT_CELLULAR)) {
-            throw new IllegalArgumentException("Mobile state need capability TRANSPORT_CELLULAR");
+    private static int getSubIdForCellularOrSatellite(@NonNull NetworkStateSnapshot state) {
+        if (!state.getNetworkCapabilities().hasTransport(NetworkCapabilities.TRANSPORT_CELLULAR)
+                // Cellular and satellite are two different mobile network transports that use
+                // the same telephony network specifier, so the satellite transport must also be
+                // accepted here for when a satellite network is active on mobile.
+                && !state.getNetworkCapabilities().hasTransport(
+                NetworkCapabilities.TRANSPORT_SATELLITE)) {
+            throw new IllegalArgumentException(
+                    "Mobile state need capability TRANSPORT_CELLULAR or TRANSPORT_SATELLITE");
         }
 
         final NetworkSpecifier spec = state.getNetworkCapabilities().getNetworkSpecifier();
diff --git a/service/Android.bp b/service/Android.bp
index 403ba7d..c35c4f8 100644
--- a/service/Android.bp
+++ b/service/Android.bp
@@ -199,6 +199,7 @@
         "PlatformProperties",
         "service-connectivity-protos",
         "service-connectivity-stats-protos",
+        "net-utils-multicast-forwarding-structs",
     ],
     apex_available: [
         "com.android.tethering",
diff --git a/staticlibs/Android.bp b/staticlibs/Android.bp
index 47e897d..6790093 100644
--- a/staticlibs/Android.bp
+++ b/staticlibs/Android.bp
@@ -188,6 +188,33 @@
     },
 }
 
+// The net-utils-multicast-forwarding-structs library requires the callers to
+// contain net-utils-device-common-bpf.
+java_library {
+    name: "net-utils-multicast-forwarding-structs",
+    srcs: [
+        "device/com/android/net/module/util/structs/StructMf6cctl.java",
+        "device/com/android/net/module/util/structs/StructMif6ctl.java",
+        "device/com/android/net/module/util/structs/StructMrt6Msg.java",
+    ],
+    sdk_version: "module_current",
+    min_sdk_version: "30",
+    visibility: [
+        "//packages/modules/Connectivity:__subpackages__",
+    ],
+    libs: [
+        // Only Struct.java is needed from "net-utils-device-common-bpf"
+        "net-utils-device-common-bpf",
+    ],
+    apex_available: [
+        "com.android.tethering",
+    ],
+    lint: {
+        strict_updatability_linting: true,
+        error_checks: ["NewApi"],
+    },
+}
+
 // The net-utils-device-common-netlink library requires the callers to contain
 // net-utils-device-common-struct.
 java_library {
diff --git a/tests/integration/AndroidManifest.xml b/tests/integration/AndroidManifest.xml
index cea83c7..1821329 100644
--- a/tests/integration/AndroidManifest.xml
+++ b/tests/integration/AndroidManifest.xml
@@ -42,6 +42,9 @@
     <uses-permission android:name="android.permission.QUERY_ALL_PACKAGES"/>
     <!-- Register UidFrozenStateChangedCallback -->
     <uses-permission android:name="android.permission.PACKAGE_USAGE_STATS"/>
+    <!-- Permission required for CTS test - NetworkStatsIntegrationTest -->
+    <uses-permission android:name="android.permission.INTERNET"/>
+    <uses-permission android:name="android.permission.CONNECTIVITY_USE_RESTRICTED_NETWORKS"/>
     <application android:debuggable="true">
         <uses-library android:name="android.test.runner"/>
 
diff --git a/tests/integration/src/com/android/server/net/integrationtests/NetworkStatsIntegrationTest.kt b/tests/integration/src/com/android/server/net/integrationtests/NetworkStatsIntegrationTest.kt
new file mode 100644
index 0000000..765e56e
--- /dev/null
+++ b/tests/integration/src/com/android/server/net/integrationtests/NetworkStatsIntegrationTest.kt
@@ -0,0 +1,597 @@
+/*
+ * Copyright (C) 2023 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License
+ */
+
+package com.android.server.net.integrationtests
+
+import android.Manifest.permission.MANAGE_TEST_NETWORKS
+import android.annotation.TargetApi
+import android.app.usage.NetworkStats
+import android.app.usage.NetworkStats.Bucket
+import android.app.usage.NetworkStats.Bucket.TAG_NONE
+import android.app.usage.NetworkStatsManager
+import android.content.Context
+import android.net.ConnectivityManager
+import android.net.ConnectivityManager.TYPE_TEST
+import android.net.InetAddresses
+import android.net.IpPrefix
+import android.net.LinkAddress
+import android.net.Network
+import android.net.NetworkCapabilities
+import android.net.NetworkRequest
+import android.net.NetworkTemplate
+import android.net.NetworkTemplate.MATCH_TEST
+import android.net.TestNetworkSpecifier
+import android.net.TrafficStats
+import android.os.Build
+import android.os.Process
+import androidx.test.platform.app.InstrumentationRegistry
+import com.android.server.net.integrationtests.NetworkStatsIntegrationTest.Direction.DOWNLOAD
+import com.android.server.net.integrationtests.NetworkStatsIntegrationTest.Direction.UPLOAD
+import com.android.testutils.DevSdkIgnoreRule.IgnoreUpTo
+import com.android.testutils.DevSdkIgnoreRunner
+import com.android.testutils.PacketBridge
+import com.android.testutils.RecorderCallback.CallbackEntry.LinkPropertiesChanged
+import com.android.testutils.TestDnsServer
+import com.android.testutils.TestHttpServer
+import com.android.testutils.TestableNetworkCallback
+import com.android.testutils.runAsShell
+import fi.iki.elonen.NanoHTTPD
+import java.io.BufferedInputStream
+import java.io.BufferedOutputStream
+import java.net.HttpURLConnection
+import java.net.HttpURLConnection.HTTP_OK
+import java.net.InetSocketAddress
+import java.net.URL
+import java.nio.charset.Charset
+import kotlin.math.ceil
+import kotlin.test.assertEquals
+import kotlin.test.assertTrue
+import org.junit.After
+import org.junit.Assume.assumeTrue
+import org.junit.Before
+import org.junit.Test
+import org.junit.runner.RunWith
+
+private const val TEST_TAG = 0xF00D
+
+@RunWith(DevSdkIgnoreRunner::class)
+@TargetApi(Build.VERSION_CODES.S)
+@IgnoreUpTo(Build.VERSION_CODES.TIRAMISU)
+class NetworkStatsIntegrationTest {
+    private val TAG = NetworkStatsIntegrationTest::class.java.simpleName
+    private val LOCAL_V6ADDR =
+        LinkAddress(InetAddresses.parseNumericAddress("2001:db8::1234"), 64)
+
+    // Remote address. Both the client and the server are given the illusion that
+    // they are talking to this address.
+    private val REMOTE_V6ADDR =
+        LinkAddress(InetAddresses.parseNumericAddress("dead:beef::808:808"), 64)
+    private val REMOTE_V4ADDR =
+        LinkAddress(InetAddresses.parseNumericAddress("8.8.8.8"), 32)
+    private val DEFAULT_MTU = 1500
+    private val DEFAULT_BUFFER_SIZE = 1500 // Any size greater than or equal to mtu
+    private val CONNECTION_TIMEOUT_MILLIS = 15000
+    private val TEST_DOWNLOAD_SIZE = 10000L
+    private val TEST_UPLOAD_SIZE = 20000L
+    private val HTTP_SERVER_NAME = "test.com"
+    private val HTTP_SERVER_PORT = 8080 // Use port > 1024 to avoid restrictions on system ports
+    private val DNS_INTERNAL_SERVER_PORT = 53
+    private val DNS_EXTERNAL_SERVER_PORT = 1053
+    private val TCP_ACK_SIZE = 72
+
+    // Packet overheads that are not part of the actual data transmission, these
+    // include DNS packets, TCP handshake/termination packets, and HTTP header
+    // packets. These overheads were gathered from real samples and may not
+    // be perfectly accurate because of DNS caches and TCP retransmissions, etc.
+    private val CONSTANT_PACKET_OVERHEAD = 8
+
+    // 130 is an observed average.
+    private val CONSTANT_BYTES_OVERHEAD = 130 * CONSTANT_PACKET_OVERHEAD
+    private val TOLERANCE = 1.3
+
+    // Set up the packet bridge with two IPv6-only test networks.
+    private val inst = InstrumentationRegistry.getInstrumentation()
+    private val context = inst.getContext()
+    private val packetBridge = runAsShell(MANAGE_TEST_NETWORKS) {
+        PacketBridge(
+            context,
+            listOf(LOCAL_V6ADDR),
+            REMOTE_V6ADDR.address,
+            listOf(
+                Pair(DNS_INTERNAL_SERVER_PORT, DNS_EXTERNAL_SERVER_PORT)
+            )
+        )
+    }
+    private val cm = context.getSystemService(ConnectivityManager::class.java)!!
+
+    // Set up DNS server for testing server and DNS64.
+    private val fakeDns = TestDnsServer(
+        packetBridge.externalNetwork,
+        InetSocketAddress(LOCAL_V6ADDR.address, DNS_EXTERNAL_SERVER_PORT)
+    ).apply {
+        start()
+        setAnswer(
+            "ipv4only.arpa",
+            listOf(IpPrefix(REMOTE_V6ADDR.address, REMOTE_V6ADDR.prefixLength).address)
+        )
+        setAnswer(HTTP_SERVER_NAME, listOf(REMOTE_V4ADDR.address))
+    }
+
+    // Start up test http server.
+    private val httpServer = TestHttpServer(
+        LOCAL_V6ADDR.address.hostAddress,
+        HTTP_SERVER_PORT
+    ).apply {
+        start()
+    }
+
+    @Before
+    fun setUp() {
+        assumeTrue(shouldRunTests())
+        packetBridge.start()
+    }
+
+    // For networkstack tests, it is not guaranteed that the tethering module will be
+    // updated at the same time. If the tethering module is not new enough, it may not contain
+    // the necessary abilities to run these tests. For example, the tests depend on test
+    // network stats being counted, which can only be achieved when they are marked as TYPE_TEST.
+    // If the tethering module does not support TYPE_TEST stats, then these tests will need
+    // to be skipped.
+    fun shouldRunTests() = cm.getNetworkInfo(packetBridge.internalNetwork)!!.type == TYPE_TEST
+
+    @After
+    fun tearDown() {
+        packetBridge.stop()
+        fakeDns.stop()
+        httpServer.stop()
+    }
+
+    private fun waitFor464XlatReady(network: Network): String {
+        val iface = cm.getLinkProperties(network)!!.interfaceName!!
+
+        // Make a network request to listen to the specific test network.
+        val nr = NetworkRequest.Builder()
+            .clearCapabilities()
+            .addTransportType(NetworkCapabilities.TRANSPORT_TEST)
+            .setNetworkSpecifier(TestNetworkSpecifier(iface))
+            .build()
+        val testCb = TestableNetworkCallback()
+        cm.registerNetworkCallback(nr, testCb)
+
+        // Wait for the stacked address to be available.
+        testCb.eventuallyExpect<LinkPropertiesChanged> {
+            it.lp.stackedLinks.getOrNull(0)?.linkAddresses?.getOrNull(0) != null
+        }
+
+        return iface
+    }
+
+    private val Network.mtu: Int get() {
+        val lp = cm.getLinkProperties(this)!!
+        val mtuStacked = if (lp.stackedLinks[0]?.mtu != 0) lp.stackedLinks[0].mtu else DEFAULT_MTU
+        val mtuInterface = if (lp.mtu != 0) lp.mtu else DEFAULT_MTU
+        return mtuInterface.coerceAtMost(mtuStacked)
+    }
+
+    /**
+     * Verify data usage download stats with test 464xlat networks.
+     *
+     * This test starts two test networks and binds them together. The internal one is for the
+     * client to make http traffic on the test network, and the external one is for the mocked
+     * http and dns servers to bind to and provide responses.
+     *
+     * After clat setup, the client will use the clat v4 address to send packets to the mocked
+     * server v4 address. Each packet is translated into a v6 packet by the clat daemon, using
+     * the NAT64 prefix learned from the mocked DNS64 response, and sent to the interface.
+     *
+     * While the packets are being forwarded to the external interface, the servers will see
+     * the packets originated from the mocked v6 address, and destined to a local v6 address.
+     */
+    @Test
+    fun test464XlatTcpStats() {
+        // Wait for 464Xlat to be ready.
+        val internalInterfaceName = waitFor464XlatReady(packetBridge.internalNetwork)
+        val mtu = packetBridge.internalNetwork.mtu
+
+        val snapshotBeforeTest = StatsSnapshot(context, internalInterfaceName)
+
+        // Generate the download traffic.
+        genHttpTraffic(packetBridge.internalNetwork, uploadSize = 0L, TEST_DOWNLOAD_SIZE)
+
+        // In practice, for one way 10k download payload, the download usage is about
+        // 11222~12880 bytes, with 14~17 packets. And the upload usage is about 1279~1626 bytes
+        // with 14~17 packets, which is mostly contributed by TCP ACK packets.
+        val snapshotAfterDownload = StatsSnapshot(context, internalInterfaceName)
+        val (expectedDownloadLower, expectedDownloadUpper) = getExpectedStatsBounds(
+            TEST_DOWNLOAD_SIZE,
+            mtu,
+            DOWNLOAD
+        )
+        assertOnlyNonTaggedStatsIncreases(
+            snapshotBeforeTest,
+            snapshotAfterDownload,
+            expectedDownloadLower,
+            expectedDownloadUpper
+        )
+
+        // Generate upload traffic with tag to verify tagged data accounting as well.
+        genHttpTrafficWithTag(
+            packetBridge.internalNetwork,
+            TEST_UPLOAD_SIZE,
+            downloadSize = 0L,
+            TEST_TAG
+        )
+
+        // Verify upload data usage accounting.
+        val snapshotAfterUpload = StatsSnapshot(context, internalInterfaceName)
+        val (expectedUploadLower, expectedUploadUpper) = getExpectedStatsBounds(
+            TEST_UPLOAD_SIZE,
+            mtu,
+            UPLOAD
+        )
+        assertAllStatsIncreases(
+            snapshotAfterDownload,
+            snapshotAfterUpload,
+            expectedUploadLower,
+            expectedUploadUpper
+        )
+    }
+
+    private enum class Direction {
+        DOWNLOAD,
+        UPLOAD
+    }
+
+    private fun getExpectedStatsBounds(
+        transmittedSize: Long,
+        mtu: Int,
+        direction: Direction
+    ): Pair<BareStats, BareStats> {
+        // This is already an underestimated value since the input doesn't include TCP/IP
+        // layer overhead.
+        val txBytesLower = transmittedSize
+        // Include TCP/IP header overheads and retransmissions in the upper bound.
+        val txBytesUpper = (transmittedSize * TOLERANCE).toLong()
+        val txPacketsLower = txBytesLower / mtu + (CONSTANT_PACKET_OVERHEAD / TOLERANCE).toLong()
+        val estTransmissionPacketsUpper = ceil(txBytesUpper / mtu.toDouble()).toLong()
+        val txPacketsUpper = estTransmissionPacketsUpper +
+                (CONSTANT_PACKET_OVERHEAD * TOLERANCE).toLong()
+        // Assume ACK only sent once for the entire transmission.
+        val rxPacketsLower = 1L + (CONSTANT_PACKET_OVERHEAD / TOLERANCE).toLong()
+        // Assume ACK sent for every RX packet.
+        val rxPacketsUpper = txPacketsUpper
+        val rxBytesLower = 1L * TCP_ACK_SIZE + (CONSTANT_BYTES_OVERHEAD / TOLERANCE).toLong()
+        val rxBytesUpper = estTransmissionPacketsUpper * TCP_ACK_SIZE +
+                (CONSTANT_BYTES_OVERHEAD * TOLERANCE).toLong()
+
+        return if (direction == UPLOAD) {
+            BareStats(rxBytesLower, rxPacketsLower, txBytesLower, txPacketsLower) to
+                    BareStats(rxBytesUpper, rxPacketsUpper, txBytesUpper, txPacketsUpper)
+        } else {
+            BareStats(txBytesLower, txPacketsLower, rxBytesLower, rxPacketsLower) to
+                    BareStats(txBytesUpper, txPacketsUpper, rxBytesUpper, rxPacketsUpper)
+        }
+    }
+
+    private fun genHttpTraffic(network: Network, uploadSize: Long, downloadSize: Long) =
+        genHttpTrafficWithTag(network, uploadSize, downloadSize, NetworkStats.Bucket.TAG_NONE)
+
+    private fun genHttpTrafficWithTag(
+        network: Network,
+        uploadSize: Long,
+        downloadSize: Long,
+        tag: Int
+    ) {
+        val path = "/test_upload_download"
+        val buf = ByteArray(DEFAULT_BUFFER_SIZE)
+
+        httpServer.addResponse(
+            TestHttpServer.Request(path, NanoHTTPD.Method.POST), NanoHTTPD.Response.Status.OK,
+            content = getRandomString(downloadSize)
+        )
+        var httpConnection: HttpURLConnection? = null
+        try {
+            TrafficStats.setThreadStatsTag(tag)
+            val spec = "http://$HTTP_SERVER_NAME:${httpServer.listeningPort}$path"
+            val url = URL(spec)
+            httpConnection = network.openConnection(url) as HttpURLConnection
+            httpConnection.connectTimeout = CONNECTION_TIMEOUT_MILLIS
+            httpConnection.requestMethod = "POST"
+            httpConnection.doOutput = true
+            // Tell the server that the response should not be compressed. Otherwise, the data usage
+            // accounted will be less than expected.
+            httpConnection.setRequestProperty("Accept-Encoding", "identity")
+            // Tell the server to close the connection after this request. This is needed to
+            // prevent reusing the same socket for traffic with a different tagging requirement.
+            httpConnection.setRequestProperty("Connection", "close")
+
+            // Send http body.
+            val outputStream = BufferedOutputStream(httpConnection.outputStream)
+            outputStream.write(getRandomString(uploadSize).toByteArray(Charset.forName("UTF-8")))
+            outputStream.close()
+            assertEquals(HTTP_OK, httpConnection.responseCode)
+
+            // Receive response from the server.
+            val inputStream = BufferedInputStream(httpConnection.getInputStream())
+            var total = 0L
+            while (true) {
+                val count = inputStream.read(buf)
+                if (count == -1) break // End-of-Stream
+                total += count
+            }
+            assertEquals(downloadSize, total)
+        } finally {
+            httpConnection?.inputStream?.close()
+            TrafficStats.clearThreadStatsTag()
+        }
+    }
+
+    // NetworkStats.Bucket cannot be written. So another class is needed to
+    // perform arithmetic operations.
+    data class BareStats(
+        val rxBytes: Long,
+        val rxPackets: Long,
+        val txBytes: Long,
+        val txPackets: Long
+    ) {
+        operator fun plus(other: BareStats): BareStats {
+            return BareStats(
+                this.rxBytes + other.rxBytes, this.rxPackets + other.rxPackets,
+                this.txBytes + other.txBytes, this.txPackets + other.txPackets
+            )
+        }
+
+        operator fun minus(other: BareStats): BareStats {
+            return BareStats(
+                this.rxBytes - other.rxBytes, this.rxPackets - other.rxPackets,
+                this.txBytes - other.txBytes, this.txPackets - other.txPackets
+            )
+        }
+
+        fun reverse(): BareStats =
+            BareStats(
+                rxBytes = txBytes,
+                rxPackets = txPackets,
+                txBytes = rxBytes,
+                txPackets = rxPackets
+            )
+
+        override fun toString(): String {
+            return "BareStats{rx/txBytes=$rxBytes/$txBytes, rx/txPackets=$rxPackets/$txPackets}"
+        }
+
+        override fun equals(other: Any?): Boolean {
+            if (this === other) return true
+            if (other !is BareStats) return false
+
+            if (rxBytes != other.rxBytes) return false
+            if (rxPackets != other.rxPackets) return false
+            if (txBytes != other.txBytes) return false
+            if (txPackets != other.txPackets) return false
+
+            return true
+        }
+
+        override fun hashCode(): Int {
+            return (rxBytes * 11 + rxPackets * 13 + txBytes * 17 + txPackets * 19).toInt()
+        }
+
+        companion object {
+            val EMPTY = BareStats(0L, 0L, 0L, 0L)
+        }
+    }
+
+    data class StatsSnapshot(val context: Context, val iface: String) {
+        val statsSummary = getNetworkSummary(iface)
+        val statsUid = getUidDetail(iface, TAG_NONE)
+        val taggedSummary = getTaggedNetworkSummary(iface, TEST_TAG)
+        val taggedUid = getUidDetail(iface, TEST_TAG)
+        val trafficStatsIface = getTrafficStatsIface(iface)
+        val trafficStatsUid = getTrafficStatsUid(Process.myUid())
+
+        private fun getUidDetail(iface: String, tag: Int): BareStats {
+            return getNetworkStatsThat(iface, tag) { nsm, template ->
+                nsm.queryDetailsForUidTagState(
+                    template, Long.MIN_VALUE, Long.MAX_VALUE,
+                    Process.myUid(), tag, Bucket.STATE_ALL
+                )
+            }
+        }
+
+        private fun getNetworkSummary(iface: String): BareStats {
+            return getNetworkStatsThat(iface, TAG_NONE) { nsm, template ->
+                nsm.querySummary(template, Long.MIN_VALUE, Long.MAX_VALUE)
+            }
+        }
+
+        private fun getTaggedNetworkSummary(iface: String, tag: Int): BareStats {
+            return getNetworkStatsThat(iface, tag) { nsm, template ->
+                nsm.queryTaggedSummary(template, Long.MIN_VALUE, Long.MAX_VALUE)
+            }
+        }
+
+        private fun getNetworkStatsThat(
+            iface: String,
+            tag: Int,
+            queryApi: (nsm: NetworkStatsManager, template: NetworkTemplate) -> NetworkStats
+        ): BareStats {
+            val nsm = context.getSystemService(NetworkStatsManager::class.java)!!
+            nsm.forceUpdate()
+            val testTemplate = NetworkTemplate.Builder(MATCH_TEST)
+                .setWifiNetworkKeys(setOf(iface)).build()
+            val stats = queryApi.invoke(nsm, testTemplate)
+            val filteredBuckets =
+                stats.buckets().filter { it.uid == Process.myUid() && it.tag == tag }
+            return filteredBuckets.fold(BareStats.EMPTY) { acc, it ->
+                acc + BareStats(
+                    it.rxBytes,
+                    it.rxPackets,
+                    it.txBytes,
+                    it.txPackets
+                )
+            }
+        }
+
+        // Helper function to iterate buckets in app.usage.NetworkStats.
+        private fun NetworkStats.buckets() = object : Iterable<NetworkStats.Bucket> {
+            override fun iterator() = object : Iterator<NetworkStats.Bucket> {
+                override operator fun hasNext() = hasNextBucket()
+                override operator fun next() =
+                    NetworkStats.Bucket().also { assertTrue(getNextBucket(it)) }
+            }
+        }
+
+        private fun getTrafficStatsIface(iface: String): BareStats = BareStats(
+            TrafficStats.getRxBytes(iface),
+            TrafficStats.getRxPackets(iface),
+            TrafficStats.getTxBytes(iface),
+            TrafficStats.getTxPackets(iface)
+        )
+
+        private fun getTrafficStatsUid(uid: Int): BareStats = BareStats(
+            TrafficStats.getUidRxBytes(uid),
+            TrafficStats.getUidRxPackets(uid),
+            TrafficStats.getUidTxBytes(uid),
+            TrafficStats.getUidTxPackets(uid)
+        )
+    }
+
+    private fun assertAllStatsIncreases(
+        before: StatsSnapshot,
+        after: StatsSnapshot,
+        lower: BareStats,
+        upper: BareStats
+    ) {
+        assertNonTaggedStatsIncreases(before, after, lower, upper)
+        assertTaggedStatsIncreases(before, after, lower, upper)
+    }
+
+    private fun assertOnlyNonTaggedStatsIncreases(
+        before: StatsSnapshot,
+        after: StatsSnapshot,
+        lower: BareStats,
+        upper: BareStats
+    ) {
+        assertNonTaggedStatsIncreases(before, after, lower, upper)
+        assertTaggedStatsEquals(before, after)
+    }
+
+    private fun assertNonTaggedStatsIncreases(
+        before: StatsSnapshot,
+        after: StatsSnapshot,
+        lower: BareStats,
+        upper: BareStats
+    ) {
+        assertInRange(
+            "Unexpected iface traffic stats",
+            after.iface,
+            before.trafficStatsIface, after.trafficStatsIface,
+            lower, upper
+        )
+        // Uid traffic stats are counted in both direction because the external network
+        // traffic is also attributed to the test uid.
+        assertInRange(
+            "Unexpected uid traffic stats",
+            after.iface,
+            before.trafficStatsUid, after.trafficStatsUid,
+            lower + lower.reverse(), upper + upper.reverse()
+        )
+        assertInRange(
+            "Unexpected non-tagged summary stats",
+            after.iface,
+            before.statsSummary, after.statsSummary,
+            lower, upper
+        )
+        assertInRange(
+            "Unexpected non-tagged uid stats",
+            after.iface,
+            before.statsUid, after.statsUid,
+            lower, upper
+        )
+    }
+
+    private fun assertTaggedStatsEquals(before: StatsSnapshot, after: StatsSnapshot) {
+        // Increment of tagged data should be zero since no tagged traffic was generated.
+        assertEquals(
+            before.taggedSummary,
+            after.taggedSummary,
+            "Unexpected tagged summary stats: ${after.iface}"
+        )
+        assertEquals(
+            before.taggedUid,
+            after.taggedUid,
+            "Unexpected tagged uid stats: ${Process.myUid()} on ${after.iface}"
+        )
+    }
+
+    private fun assertTaggedStatsIncreases(
+        before: StatsSnapshot,
+        after: StatsSnapshot,
+        lower: BareStats,
+        upper: BareStats
+    ) {
+        assertInRange(
+            "Unexpected tagged summary stats",
+            after.iface,
+            before.taggedSummary, after.taggedSummary,
+            lower,
+            upper
+        )
+        assertInRange(
+            "Unexpected tagged uid stats: ${Process.myUid()}",
+            after.iface,
+            before.taggedUid, after.taggedUid,
+            lower,
+            upper
+        )
+    }
+
+    /** Verify the given BareStats is in range [lower, upper] */
+    private fun assertInRange(
+        tag: String,
+        iface: String,
+        before: BareStats,
+        after: BareStats,
+        lower: BareStats,
+        upper: BareStats
+    ) {
+        // Passing the value after operation and the value before operation to dump the actual
+        // numbers if it fails.
+        assertTrue(checkInRange(before, after, lower, upper),
+            "$tag on $iface: $after - $before is not within range [$lower, $upper]"
+        )
+    }
+
+    private fun checkInRange(
+            before: BareStats,
+            after: BareStats,
+            lower: BareStats,
+            upper: BareStats
+    ): Boolean {
+        val value = after - before
+        return value.rxBytes in lower.rxBytes..upper.rxBytes &&
+                value.rxPackets in lower.rxPackets..upper.rxPackets &&
+                value.txBytes in lower.txBytes..upper.txBytes &&
+                value.txPackets in lower.txPackets..upper.txPackets
+    }
+
+    fun getRandomString(length: Long): String {
+        val allowedChars = ('A'..'Z') + ('a'..'z') + ('0'..'9')
+        return (1..length)
+            .map { allowedChars.random() }
+            .joinToString("")
+    }
+}
diff --git a/tests/unit/vpn-jarjar-rules.txt b/tests/unit/vpn-jarjar-rules.txt
index 1a6bddc..f74eab8 100644
--- a/tests/unit/vpn-jarjar-rules.txt
+++ b/tests/unit/vpn-jarjar-rules.txt
@@ -1,4 +1,2 @@
 # Only keep classes imported by ConnectivityServiceTest
-keep com.android.server.connectivity.Vpn
 keep com.android.server.connectivity.VpnProfileStore
-keep com.android.server.net.LockdownVpnTracker
diff --git a/thread/tests/cts/Android.bp b/thread/tests/cts/Android.bp
index 676eb0e..c1cf0a0 100644
--- a/thread/tests/cts/Android.bp
+++ b/thread/tests/cts/Android.bp
@@ -42,6 +42,7 @@
         "guava",
         "guava-android-testlib",
         "net-tests-utils",
+        "ThreadNetworkTestUtils",
         "truth",
     ],
     libs: [
diff --git a/thread/tests/cts/src/android/net/thread/cts/ThreadNetworkControllerTest.java b/thread/tests/cts/src/android/net/thread/cts/ThreadNetworkControllerTest.java
index 3bec36b..36ce4d5 100644
--- a/thread/tests/cts/src/android/net/thread/cts/ThreadNetworkControllerTest.java
+++ b/thread/tests/cts/src/android/net/thread/cts/ThreadNetworkControllerTest.java
@@ -17,7 +17,6 @@
 package android.net.thread.cts;
 
 import static android.Manifest.permission.ACCESS_NETWORK_STATE;
-import static android.Manifest.permission.MANAGE_TEST_NETWORKS;
 import static android.net.thread.ThreadNetworkController.DEVICE_ROLE_CHILD;
 import static android.net.thread.ThreadNetworkController.DEVICE_ROLE_LEADER;
 import static android.net.thread.ThreadNetworkController.DEVICE_ROLE_ROUTER;
@@ -33,7 +32,6 @@
 
 import static androidx.test.platform.app.InstrumentationRegistry.getInstrumentation;
 
-import static com.android.testutils.TestNetworkTrackerKt.initTestNetwork;
 import static com.android.testutils.TestPermissionUtil.runAsShell;
 
 import static com.google.common.truth.Truth.assertThat;
@@ -48,7 +46,6 @@
 
 import android.content.Context;
 import android.net.ConnectivityManager;
-import android.net.LinkAddress;
 import android.net.Network;
 import android.net.NetworkCapabilities;
 import android.net.NetworkRequest;
@@ -62,7 +59,9 @@
 import android.net.thread.ThreadNetworkController.StateCallback;
 import android.net.thread.ThreadNetworkException;
 import android.net.thread.ThreadNetworkManager;
+import android.net.thread.utils.TapTestNetworkTracker;
 import android.os.Build;
+import android.os.HandlerThread;
 import android.os.OutcomeReceiver;
 
 import androidx.annotation.NonNull;
@@ -74,7 +73,6 @@
 import com.android.testutils.DevSdkIgnoreRule.IgnoreUpTo;
 import com.android.testutils.DevSdkIgnoreRunner;
 import com.android.testutils.FunctionalUtils.ThrowingRunnable;
-import com.android.testutils.TestNetworkTracker;
 
 import org.junit.After;
 import org.junit.Before;
@@ -110,7 +108,7 @@
     private static final int NETWORK_CALLBACK_TIMEOUT_MILLIS = 10 * 1000;
     private static final int CALLBACK_TIMEOUT_MILLIS = 1_000;
     private static final int ENABLED_TIMEOUT_MILLIS = 2_000;
-    private static final int SERVICE_DISCOVERY_TIMEOUT_MILLIS = 10 * 1000;
+    private static final int SERVICE_DISCOVERY_TIMEOUT_MILLIS = 30_000;
     private static final String MESHCOP_SERVICE_TYPE = "_meshcop._udp";
     private static final String THREAD_NETWORK_PRIVILEGED =
             "android.permission.THREAD_NETWORK_PRIVILEGED";
@@ -123,12 +121,11 @@
     private NsdManager mNsdManager;
 
     private Set<String> mGrantedPermissions;
+    private HandlerThread mHandlerThread;
+    private TapTestNetworkTracker mTestNetworkTracker;
 
     @Before
     public void setUp() throws Exception {
-
-        mGrantedPermissions = new HashSet<String>();
-        mExecutor = Executors.newSingleThreadExecutor();
         ThreadNetworkManager manager = mContext.getSystemService(ThreadNetworkManager.class);
         if (manager != null) {
             mController = manager.getAllThreadNetworkControllers().get(0);
@@ -138,20 +135,23 @@
         // tests if a feature is not available.
         assumeNotNull(mController);
 
-        setEnabledAndWait(mController, true);
-
+        mGrantedPermissions = new HashSet<String>();
+        mExecutor = Executors.newSingleThreadExecutor();
         mNsdManager = mContext.getSystemService(NsdManager.class);
+        mHandlerThread = new HandlerThread(this.getClass().getSimpleName());
+        mHandlerThread.start();
+
+        setEnabledAndWait(mController, true);
     }
 
     @After
     public void tearDown() throws Exception {
-        if (mController != null) {
-            grantPermissions(THREAD_NETWORK_PRIVILEGED);
-            CompletableFuture<Void> future = new CompletableFuture<>();
-            mController.leave(mExecutor, future::complete);
-            future.get(LEAVE_TIMEOUT_MILLIS, MILLISECONDS);
+        if (mController == null) {
+            return;
         }
         dropAllPermissions();
+        leaveAndWait(mController);
+        tearDownTestNetwork();
     }
 
     @Test
@@ -829,7 +829,7 @@
 
     @Test
     public void meshcopService_threadEnabledButNotJoined_discoveredButNoNetwork() throws Exception {
-        TestNetworkTracker testNetwork = setUpTestNetwork();
+        setUpTestNetwork();
 
         setEnabledAndWait(mController, true);
         leaveAndWait(mController);
@@ -845,13 +845,11 @@
         assertThat(txtMap.get("rv")).isNotNull();
         assertThat(txtMap.get("tv")).isNotNull();
         assertThat(txtMap.get("sb")).isNotNull();
-
-        tearDownTestNetwork(testNetwork);
     }
 
     @Test
     public void meshcopService_joinedNetwork_discoveredHasNetwork() throws Exception {
-        TestNetworkTracker testNetwork = setUpTestNetwork();
+        setUpTestNetwork();
 
         String networkName = "TestNet" + new Random().nextInt(10_000);
         joinRandomizedDatasetAndWait(mController, networkName);
@@ -872,27 +870,26 @@
         assertThat(txtMap.get("tv")).isNotNull();
         assertThat(txtMap.get("sb")).isNotNull();
         assertThat(txtMap.get("id").length).isEqualTo(16);
-
-        tearDownTestNetwork(testNetwork);
     }
 
     @Test
     public void meshcopService_threadDisabled_notDiscovered() throws Exception {
-        TestNetworkTracker testNetwork = setUpTestNetwork();
+        setUpTestNetwork();
 
         CompletableFuture<NsdServiceInfo> serviceLostFuture = new CompletableFuture<>();
         NsdManager.DiscoveryListener listener =
                 discoverForServiceLost(MESHCOP_SERVICE_TYPE, serviceLostFuture);
         setEnabledAndWait(mController, false);
-
         try {
-            serviceLostFuture.get(10_000, MILLISECONDS);
+            serviceLostFuture.get(SERVICE_DISCOVERY_TIMEOUT_MILLIS, MILLISECONDS);
+        } catch (InterruptedException | ExecutionException | TimeoutException ignored) {
+            // It's fine if the service lost event didn't show up. The service may not ever be
+            // advertised.
         } finally {
             mNsdManager.stopServiceDiscovery(listener);
         }
-        assertThrows(TimeoutException.class, () -> discoverService(MESHCOP_SERVICE_TYPE));
 
-        tearDownTestNetwork(testNetwork);
+        assertThrows(TimeoutException.class, () -> discoverService(MESHCOP_SERVICE_TYPE));
     }
 
     private static void dropAllPermissions() {
@@ -1163,14 +1160,17 @@
         }
     }
 
-    TestNetworkTracker setUpTestNetwork() {
-        return runAsShell(
-                MANAGE_TEST_NETWORKS,
-                () -> initTestNetwork(mContext, new LinkAddress("2001:db8:123::/64"), 10_000));
+    private void setUpTestNetwork() {
+        assertThat(mTestNetworkTracker).isNull();
+        mTestNetworkTracker = new TapTestNetworkTracker(mContext, mHandlerThread.getLooper());
     }
 
-    void tearDownTestNetwork(TestNetworkTracker testNetwork) {
-        runAsShell(MANAGE_TEST_NETWORKS, () -> testNetwork.teardown());
+    private void tearDownTestNetwork() throws InterruptedException {
+        if (mTestNetworkTracker != null) {
+            mTestNetworkTracker.tearDown();
+        }
+        mHandlerThread.quitSafely();
+        mHandlerThread.join();
     }
 
     private static class DefaultDiscoveryListener implements NsdManager.DiscoveryListener {
diff --git a/thread/tests/integration/src/android/net/thread/BorderRoutingTest.java b/thread/tests/integration/src/android/net/thread/BorderRoutingTest.java
index 7aaae86..e8ef346 100644
--- a/thread/tests/integration/src/android/net/thread/BorderRoutingTest.java
+++ b/thread/tests/integration/src/android/net/thread/BorderRoutingTest.java
@@ -374,9 +374,13 @@
         subscribeMulticastAddressAndWait(ftd2, GROUP_ADDR_SCOPE_4);
 
         mInfraDevice.sendEchoRequest(GROUP_ADDR_SCOPE_5);
-        mInfraDevice.sendEchoRequest(GROUP_ADDR_SCOPE_4);
 
         assertNotNull(pollForPacketOnInfraNetwork(ICMPV6_ECHO_REPLY_TYPE, ftd1.getOmrAddress()));
+
+        // Verify ping reply from ftd1 and ftd2 separately as the order of replies can't be
+        // predicted.
+        mInfraDevice.sendEchoRequest(GROUP_ADDR_SCOPE_4);
+
         assertNotNull(pollForPacketOnInfraNetwork(ICMPV6_ECHO_REPLY_TYPE, ftd2.getOmrAddress()));
     }
 
@@ -405,11 +409,13 @@
         startFtdChild(ftd2);
         subscribeMulticastAddressAndWait(ftd2, GROUP_ADDR_SCOPE_5);
 
-        // Send the request twice as the order of replies from ftd1 and ftd2 are not guaranteed
-        mInfraDevice.sendEchoRequest(GROUP_ADDR_SCOPE_5);
         mInfraDevice.sendEchoRequest(GROUP_ADDR_SCOPE_5);
 
         assertNotNull(pollForPacketOnInfraNetwork(ICMPV6_ECHO_REPLY_TYPE, ftd1.getOmrAddress()));
+
+        // Send the request twice as the order of replies from ftd1 and ftd2 are not guaranteed
+        mInfraDevice.sendEchoRequest(GROUP_ADDR_SCOPE_5);
+
         assertNotNull(pollForPacketOnInfraNetwork(ICMPV6_ECHO_REPLY_TYPE, ftd2.getOmrAddress()));
     }
 
diff --git a/thread/tests/unit/src/com/android/server/thread/ThreadNetworkControllerServiceTest.java b/thread/tests/unit/src/com/android/server/thread/ThreadNetworkControllerServiceTest.java
index 4948c22..60a5f2b 100644
--- a/thread/tests/unit/src/com/android/server/thread/ThreadNetworkControllerServiceTest.java
+++ b/thread/tests/unit/src/com/android/server/thread/ThreadNetworkControllerServiceTest.java
@@ -16,7 +16,6 @@
 
 package com.android.server.thread;
 
-import static android.Manifest.permission.ACCESS_NETWORK_STATE;
 import static android.net.thread.ActiveOperationalDataset.CHANNEL_PAGE_24_GHZ;
 import static android.net.thread.ThreadNetworkController.STATE_DISABLED;
 import static android.net.thread.ThreadNetworkController.STATE_ENABLED;
diff --git a/thread/tests/utils/Android.bp b/thread/tests/utils/Android.bp
new file mode 100644
index 0000000..24e9bb9
--- /dev/null
+++ b/thread/tests/utils/Android.bp
@@ -0,0 +1,37 @@
+//
+// Copyright (C) 2023 The Android Open Source Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+
+package {
+    default_team: "trendy_team_fwk_thread_network",
+    default_applicable_licenses: ["Android-Apache-2.0"],
+}
+
+java_library {
+    name: "ThreadNetworkTestUtils",
+    min_sdk_version: "30",
+    static_libs: [
+        "compatibility-device-util-axt",
+        "net-tests-utils",
+        "net-utils-device-common",
+        "net-utils-device-common-bpf",
+    ],
+    srcs: [
+        "src/**/*.java",
+    ],
+    defaults: [
+        "framework-connectivity-test-defaults",
+    ],
+}
diff --git a/thread/tests/utils/src/android/net/thread/utils/TapTestNetworkTracker.java b/thread/tests/utils/src/android/net/thread/utils/TapTestNetworkTracker.java
new file mode 100644
index 0000000..43f177d
--- /dev/null
+++ b/thread/tests/utils/src/android/net/thread/utils/TapTestNetworkTracker.java
@@ -0,0 +1,185 @@
+/*
+ * Copyright (C) 2024 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package android.net.thread.utils;
+
+import static android.Manifest.permission.MANAGE_TEST_NETWORKS;
+import static android.net.InetAddresses.parseNumericAddress;
+import static android.net.NetworkCapabilities.NET_CAPABILITY_TRUSTED;
+import static android.net.NetworkCapabilities.TRANSPORT_TEST;
+import static android.system.OsConstants.AF_INET6;
+import static android.system.OsConstants.IPPROTO_UDP;
+import static android.system.OsConstants.SOCK_DGRAM;
+
+import static com.android.testutils.RecorderCallback.CallbackEntry.LINK_PROPERTIES_CHANGED;
+import static com.android.testutils.TestPermissionUtil.runAsShell;
+
+import android.content.Context;
+import android.net.ConnectivityManager;
+import android.net.LinkAddress;
+import android.net.LinkProperties;
+import android.net.Network;
+import android.net.NetworkAgentConfig;
+import android.net.NetworkCapabilities;
+import android.net.NetworkRequest;
+import android.net.TestNetworkInterface;
+import android.net.TestNetworkManager;
+import android.net.TestNetworkSpecifier;
+import android.os.Looper;
+import android.system.ErrnoException;
+import android.system.Os;
+
+import com.android.compatibility.common.util.PollingCheck;
+import com.android.testutils.TestableNetworkAgent;
+import com.android.testutils.TestableNetworkCallback;
+
+import java.io.FileDescriptor;
+import java.io.IOException;
+import java.net.InterfaceAddress;
+import java.net.NetworkInterface;
+import java.net.SocketException;
+import java.time.Duration;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+
+/** A class that can create/destroy a test network based on a TAP interface. */
+public final class TapTestNetworkTracker {
+    private static final Duration TIMEOUT = Duration.ofSeconds(2);
+    private final Context mContext;
+    private final Looper mLooper;
+    private TestNetworkInterface mInterface;
+    private TestableNetworkAgent mAgent;
+    private final TestableNetworkCallback mNetworkCallback;
+    private final ConnectivityManager mConnectivityManager;
+
+    /**
+     * Constructs a {@link TapTestNetworkTracker}.
+     *
+     * <p>It creates a TAP interface (e.g. testtap0) and registers a test network using that
+     * interface. It also requests the test network by {@link ConnectivityManager#requestNetwork} so
+     * the test network won't be automatically turned down by {@link
+     * com.android.server.ConnectivityService}.
+     */
+    public TapTestNetworkTracker(Context context, Looper looper) {
+        mContext = context;
+        mLooper = looper;
+        mConnectivityManager = mContext.getSystemService(ConnectivityManager.class);
+        mNetworkCallback = new TestableNetworkCallback();
+        runAsShell(MANAGE_TEST_NETWORKS, this::setUpTestNetwork);
+    }
+
+    /** Tears down the test network. */
+    public void tearDown() {
+        runAsShell(MANAGE_TEST_NETWORKS, this::tearDownTestNetwork);
+    }
+
+    /** Returns the interface name of the test network. */
+    public String getInterfaceName() {
+        return mInterface.getInterfaceName();
+    }
+
+    /** Creates the TAP interface and brings up a connected test network bound to it. */
+    private void setUpTestNetwork() throws Exception {
+        mInterface = mContext.getSystemService(TestNetworkManager.class).createTapInterface();
+
+        // Hold a request on the test network so that it is not torn down automatically.
+        mConnectivityManager.requestNetwork(newNetworkRequest(), mNetworkCallback);
+
+        LinkProperties lp = new LinkProperties();
+        lp.setInterfaceName(getInterfaceName());
+        NetworkAgentConfig config = new NetworkAgentConfig.Builder().build();
+        mAgent =
+                new TestableNetworkAgent(mContext, mLooper, newNetworkCapabilities(), lp, config);
+        final Network network = mAgent.register();
+        mAgent.markConnected();
+
+        // Wait for a usable address before publishing the interface addresses to the network.
+        PollingCheck.check(
+                "No usable address on interface",
+                TIMEOUT.toMillis(),
+                () -> hasUsableAddress(network, getInterfaceName()));
+
+        lp.setLinkAddresses(makeLinkAddresses());
+        mAgent.sendLinkProperties(lp);
+        mNetworkCallback.eventuallyExpect(
+                LINK_PROPERTIES_CHANGED,
+                TIMEOUT.toMillis(),
+                l -> !l.getLp().getAddresses().isEmpty());
+    }
+
+    /** Releases the network request, unregisters the agent and closes the TAP interface. */
+    private void tearDownTestNetwork() throws IOException {
+        mConnectivityManager.unregisterNetworkCallback(mNetworkCallback);
+        mAgent.unregister();
+        mInterface.getFileDescriptor().close();
+        mAgent.waitForIdle(TIMEOUT.toMillis());
+    }
+
+    private NetworkRequest newNetworkRequest() {
+        return new NetworkRequest.Builder()
+                .removeCapability(NET_CAPABILITY_TRUSTED)
+                .addTransportType(TRANSPORT_TEST)
+                .setNetworkSpecifier(new TestNetworkSpecifier(getInterfaceName()))
+                .build();
+    }
+
+    private NetworkCapabilities newNetworkCapabilities() {
+        return new NetworkCapabilities()
+                .removeCapability(NET_CAPABILITY_TRUSTED)
+                .addTransportType(TRANSPORT_TEST)
+                .setNetworkSpecifier(new TestNetworkSpecifier(getInterfaceName()));
+    }
+
+    /** Returns the addresses currently configured on the TAP interface as LinkAddresses. */
+    private List<LinkAddress> makeLinkAddresses() {
+        List<InterfaceAddress> interfaceAddresses = Collections.emptyList();
+        try {
+            NetworkInterface iface = NetworkInterface.getByName(getInterfaceName());
+            if (iface != null) { // getByName() returns null for an unknown interface.
+                interfaceAddresses = iface.getInterfaceAddresses();
+            }
+        } catch (SocketException ignored) {
+            // Best effort: fall back to an empty list if the addresses cannot be read.
+        }
+        List<LinkAddress> linkAddresses = new ArrayList<>();
+        for (InterfaceAddress address : interfaceAddresses) {
+            linkAddresses.add(
+                    new LinkAddress(address.getAddress(), address.getNetworkPrefixLength()));
+        }
+        return linkAddresses;
+    }
+
+    /** Returns true if the interface has an address and a socket can bind/connect through it. */
+    private static boolean hasUsableAddress(Network network, String interfaceName) {
+        try {
+            NetworkInterface iface = NetworkInterface.getByName(interfaceName);
+            if (iface == null || iface.getInterfaceAddresses().isEmpty()) {
+                return false;
+            }
+            // Address flags need elevated permissions, so probe with bindSocket/connect instead.
+            FileDescriptor sock = Os.socket(AF_INET6, SOCK_DGRAM, IPPROTO_UDP);
+            try {
+                network.bindSocket(sock);
+                Os.connect(sock, parseNumericAddress("ff02::fb%" + interfaceName), 12345);
+                return true;
+            } finally {
+                Os.close(sock); // Always release the fd: PollingCheck retries this method.
+            }
+        } catch (ErrnoException | IOException e) {
+            return false;
+        }
+    }
+}