diff --git a/service/src/com/android/server/ConnectivityService.java b/service/src/com/android/server/ConnectivityService.java
index dc855c1..0fe24a2 100755
--- a/service/src/com/android/server/ConnectivityService.java
+++ b/service/src/com/android/server/ConnectivityService.java
@@ -9672,10 +9672,10 @@
      * interfaces.
      * Ingress discard rule is added to the address iff
      *   1. The address is not a link local address
-     *   2. The address is used by a single non-Legacy VPN interface and not used by any other
-     *      interfaces even non-VPN ones
-     * Ingress discard rule is not be added to Legacy VPN since some Legacy VPNs need to receive
-     * packet to VPN address via non-VPN interface.
+     *   2. The address is used by a single interface of a VPN whose type is neither
+     *      TYPE_VPN_LEGACY nor TYPE_VPN_OEM, and is not used by any other interface, VPN or not.
+     * Ingress discard rules are not added for TYPE_VPN_LEGACY or TYPE_VPN_OEM VPNs, since these
+     * VPNs might need to receive packets addressed to the VPN address via a non-VPN interface.
      * This method can be called during network disconnects, when nai has already been removed from
      * mNetworkAgentInfos.
      *
@@ -9710,8 +9710,10 @@
         // for different network.
         final Set<Pair<InetAddress, String>> ingressDiscardRules = new ArraySet<>();
         for (final NetworkAgentInfo agent : nais) {
+            final int vpnType = getVpnType(agent);
             if (!agent.isVPN() || agent.isDestroyed()
-                    || getVpnType(agent) == VpnManager.TYPE_VPN_LEGACY) {
+                    || vpnType == VpnManager.TYPE_VPN_LEGACY
+                    || vpnType == VpnManager.TYPE_VPN_OEM) {
                 continue;
             }
             final LinkProperties agentLp = (nai == agent) ? lp : agent.linkProperties;
diff --git a/tests/unit/java/com/android/server/connectivityservice/CSIngressDiscardRuleTests.kt b/tests/unit/java/com/android/server/connectivityservice/CSIngressDiscardRuleTests.kt
index 1ae77e5..77b06b2 100644
--- a/tests/unit/java/com/android/server/connectivityservice/CSIngressDiscardRuleTests.kt
+++ b/tests/unit/java/com/android/server/connectivityservice/CSIngressDiscardRuleTests.kt
@@ -26,6 +26,7 @@
 import android.net.NetworkCapabilities.TRANSPORT_VPN
 import android.net.NetworkCapabilities.TRANSPORT_WIFI
 import android.net.NetworkRequest
+import android.net.VpnManager.TYPE_VPN_OEM
 import android.net.VpnManager.TYPE_VPN_SERVICE
 import android.net.VpnManager.TYPE_VPN_LEGACY
 import android.net.VpnTransportInfo
@@ -50,11 +51,10 @@
 private const val TIMEOUT_MS = 1_000L
 private const val LONG_TIMEOUT_MS = 5_000
 
-private fun vpnNc(legacyVpn: Boolean = false) = NetworkCapabilities.Builder().apply {
+private fun vpnNc(vpnType: Int = TYPE_VPN_SERVICE) = NetworkCapabilities.Builder().apply {
     addTransportType(TRANSPORT_VPN)
     removeCapability(NET_CAPABILITY_NOT_VPN)
     addCapability(NET_CAPABILITY_NOT_VCN_MANAGED)
-    val vpnType = if (legacyVpn) { TYPE_VPN_LEGACY } else { TYPE_VPN_SERVICE }
     setTransportInfo(
             VpnTransportInfo(
                     vpnType,
@@ -313,18 +313,46 @@
         verify(bpfNetMaps, never()).setIngressDiscardRule(any(), any())
     }
 
-    @Test
-    fun testVpnIngressDiscardRule_LegacyVpn() {
+    /**
+     * Connects a VPN agent whose [VpnTransportInfo] carries [vpnType] and verifies the resulting
+     * ingress discard rule.
+     *
+     * @param vpnType the `VpnManager.TYPE_VPN_*` value passed to [vpnNc].
+     * @param expectAddRule if true, expects a rule for [IPV6_ADDRESS] on [VPN_IFNAME]; if false,
+     *     expects that no ingress discard rule is set at all.
+     */
+    private fun doTestVpnIngressDiscardRule_VpnType(vpnType: Int, expectAddRule: Boolean) {
         val nr = nr(TRANSPORT_VPN)
         val cb = TestableNetworkCallback()
         cm.registerNetworkCallback(nr, cb)
-        val nc = vpnNc(legacyVpn = true)
+        val nc = vpnNc(vpnType)
         val lp = lp(VPN_IFNAME, IPV6_LINK_ADDRESS, LOCAL_IPV6_LINK_ADDRRESS)
         val agent = Agent(nc = nc, lp = lp)
         agent.connect()
         cb.expectAvailableCallbacks(agent.network, validated = false)
 
+        if (expectAddRule) {
+            verify(bpfNetMaps).setIngressDiscardRule(IPV6_ADDRESS, VPN_IFNAME)
+        } else {
+            verify(bpfNetMaps, never()).setIngressDiscardRule(any(), any())
+        }
+    }
+
+    @Test
+    fun testVpnIngressDiscardRule_ServiceVpn() {
+        // IngressDiscardRule should be added to VPNs of TYPE_VPN_SERVICE
+        doTestVpnIngressDiscardRule_VpnType(TYPE_VPN_SERVICE, expectAddRule = true)
+    }
+
+    @Test
+    fun testVpnIngressDiscardRule_LegacyVpn() {
         // IngressDiscardRule should not be added to Legacy VPN
-        verify(bpfNetMaps, never()).setIngressDiscardRule(any(), any())
+        doTestVpnIngressDiscardRule_VpnType(TYPE_VPN_LEGACY, expectAddRule = false)
+    }
+
+    @Test
+    fun testVpnIngressDiscardRule_OemVpn() {
+        // IngressDiscardRule should not be added to OEM VPN
+        doTestVpnIngressDiscardRule_VpnType(TYPE_VPN_OEM, expectAddRule = false)
     }
 }
diff --git a/thread/service/java/com/android/server/thread/ThreadNetworkControllerService.java b/thread/service/java/com/android/server/thread/ThreadNetworkControllerService.java
index 622b482..b621a6a 100644
--- a/thread/service/java/com/android/server/thread/ThreadNetworkControllerService.java
+++ b/thread/service/java/com/android/server/thread/ThreadNetworkControllerService.java
@@ -238,7 +238,11 @@
         mInfraIfController = infraIfController;
         mUpstreamNetworkRequest = newUpstreamNetworkRequest();
         mNetworkToInterface = new HashMap<Network, String>();
-        mBorderRouterConfig = new BorderRouterConfiguration();
+        mBorderRouterConfig =
+                new BorderRouterConfiguration.Builder()
+                        .setIsBorderRoutingEnabled(true)
+                        .setInfraInterfaceName(null)
+                        .build();
         mPersistentSettings = persistentSettings;
         mNsdPublisher = nsdPublisher;
         mUserManager = userManager;
@@ -1228,40 +1232,54 @@
         }
     }
 
-    private void enableBorderRouting(String infraIfName) {
-        if (mBorderRouterConfig.isBorderRoutingEnabled
-                && infraIfName.equals(mBorderRouterConfig.infraInterfaceName)) {
+    private void configureBorderRouter(BorderRouterConfiguration borderRouterConfig) {
+        if (mBorderRouterConfig.equals(borderRouterConfig)) {
             return;
         }
-        Log.i(TAG, "Enable border routing on AIL: " + infraIfName);
+        Log.i(
+                TAG,
+                "Configuring Border Router: " + mBorderRouterConfig + " -> " + borderRouterConfig);
+        mBorderRouterConfig = borderRouterConfig; // recorded even if the daemon call below fails
+        ParcelFileDescriptor infraIcmp6Socket = null;
+        if (mBorderRouterConfig.infraInterfaceName != null) {
+            try {
+                infraIcmp6Socket =
+                        mInfraIfController.createIcmp6Socket(
+                                mBorderRouterConfig.infraInterfaceName);
+            } catch (IOException e) {
+                Log.w(TAG, "Failed to create ICMPv6 socket on infra network interface", e);
+            }
+        }
         try {
-            mBorderRouterConfig.infraInterfaceName = infraIfName;
-            mBorderRouterConfig.isBorderRoutingEnabled = true;
-            ParcelFileDescriptor infraIcmp6Socket =
-                    mInfraIfController.createIcmp6Socket(infraIfName);
             getOtDaemon()
                     .configureBorderRouter(
                             mBorderRouterConfig,
                             infraIcmp6Socket,
                             new ConfigureBorderRouterStatusReceiver());
-        } catch (RemoteException | IOException | ThreadNetworkException e) {
-            Log.w(TAG, "Failed to enable border routing", e);
+        } catch (RemoteException | ThreadNetworkException e) {
+            Log.w(TAG, "Failed to configure border router " + mBorderRouterConfig, e);
         }
     }
 
+    private void enableBorderRouting(String infraIfName) {
+        BorderRouterConfiguration borderRouterConfig =
+                newBorderRouterConfigBuilder(mBorderRouterConfig)
+                        .setIsBorderRoutingEnabled(true)
+                        .setInfraInterfaceName(infraIfName)
+                        .build();
+        Log.i(TAG, "Enable border routing on AIL: " + infraIfName);
+        configureBorderRouter(borderRouterConfig);
+    }
+
     private void disableBorderRouting() {
         mUpstreamNetwork = null;
-        mBorderRouterConfig.infraInterfaceName = null;
-        mBorderRouterConfig.isBorderRoutingEnabled = false;
-        try {
-            getOtDaemon()
-                    .configureBorderRouter(
-                            mBorderRouterConfig,
-                            null /* infraIcmp6Socket */,
-                            new ConfigureBorderRouterStatusReceiver());
-        } catch (RemoteException | ThreadNetworkException e) {
-            Log.w(TAG, "Failed to disable border routing", e);
-        }
+        BorderRouterConfiguration borderRouterConfig =
+                newBorderRouterConfigBuilder(mBorderRouterConfig)
+                        .setIsBorderRoutingEnabled(false)
+                        .setInfraInterfaceName(null)
+                        .build();
+        Log.i(TAG, "Disabling border routing");
+        configureBorderRouter(borderRouterConfig);
     }
 
     private void handleThreadInterfaceStateChanged(boolean isUp) {
@@ -1362,6 +1380,18 @@
         return builder.build();
     }
 
+    /**
+     * Returns a new {@link BorderRouterConfiguration.Builder} initialized with the
+     * border-routing-enabled flag and infra interface name of {@code template}, so that callers
+     * can override individual fields to build a modified copy.
+     */
+    private static BorderRouterConfiguration.Builder newBorderRouterConfigBuilder(
+            BorderRouterConfiguration template) {
+        final BorderRouterConfiguration.Builder builder = new BorderRouterConfiguration.Builder();
+        return builder.setIsBorderRoutingEnabled(template.isBorderRoutingEnabled)
+                .setInfraInterfaceName(template.infraInterfaceName);
+    }
+
     private static final class CallbackMetadata {
         private static long gId = 0;
 
diff --git a/thread/tests/integration/src/android/net/thread/BorderRoutingTest.java b/thread/tests/integration/src/android/net/thread/BorderRoutingTest.java
index b6d9aa3..9e8dc3a 100644
--- a/thread/tests/integration/src/android/net/thread/BorderRoutingTest.java
+++ b/thread/tests/integration/src/android/net/thread/BorderRoutingTest.java
@@ -17,17 +17,20 @@
 package android.net.thread;
 
 import static android.Manifest.permission.MANAGE_TEST_NETWORKS;
+import static android.net.InetAddresses.parseNumericAddress;
 import static android.net.thread.utils.IntegrationTestUtils.DEFAULT_DATASET;
 import static android.net.thread.utils.IntegrationTestUtils.getIpv6LinkAddresses;
+import static android.net.thread.utils.IntegrationTestUtils.isExpectedIcmpv4Packet;
 import static android.net.thread.utils.IntegrationTestUtils.isExpectedIcmpv6Packet;
-import static android.net.thread.utils.IntegrationTestUtils.isFromIpv6Source;
+import static android.net.thread.utils.IntegrationTestUtils.isFrom;
 import static android.net.thread.utils.IntegrationTestUtils.isInMulticastGroup;
-import static android.net.thread.utils.IntegrationTestUtils.isToIpv6Destination;
+import static android.net.thread.utils.IntegrationTestUtils.isTo;
 import static android.net.thread.utils.IntegrationTestUtils.joinNetworkAndWaitForOmr;
 import static android.net.thread.utils.IntegrationTestUtils.newPacketReader;
 import static android.net.thread.utils.IntegrationTestUtils.pollForPacket;
 import static android.net.thread.utils.IntegrationTestUtils.sendUdpMessage;
 import static android.net.thread.utils.IntegrationTestUtils.waitFor;
+import static android.system.OsConstants.ICMP_ECHO;
 
 import static com.android.net.module.util.NetworkStackConstants.ICMPV6_ECHO_REPLY_TYPE;
 import static com.android.net.module.util.NetworkStackConstants.ICMPV6_ECHO_REQUEST_TYPE;
@@ -44,11 +47,11 @@
 import static java.util.Objects.requireNonNull;
 
 import android.content.Context;
-import android.net.InetAddresses;
 import android.net.IpPrefix;
 import android.net.LinkAddress;
 import android.net.LinkProperties;
 import android.net.MacAddress;
+import android.net.RouteInfo;
 import android.net.thread.utils.FullThreadDevice;
 import android.net.thread.utils.InfraNetworkDevice;
 import android.net.thread.utils.OtDaemonController;
@@ -74,7 +77,9 @@
 import org.junit.Test;
 import org.junit.runner.RunWith;
 
+import java.net.Inet4Address;
 import java.net.Inet6Address;
+import java.net.InetAddress;
 import java.time.Duration;
 import java.util.ArrayList;
 import java.util.List;
@@ -89,11 +94,14 @@
     private static final String TAG = BorderRoutingTest.class.getSimpleName();
     private static final int NUM_FTD = 2;
     private static final Inet6Address GROUP_ADDR_SCOPE_5 =
-            (Inet6Address) InetAddresses.parseNumericAddress("ff05::1234");
+            (Inet6Address) parseNumericAddress("ff05::1234");
     private static final Inet6Address GROUP_ADDR_SCOPE_4 =
-            (Inet6Address) InetAddresses.parseNumericAddress("ff04::1234");
+            (Inet6Address) parseNumericAddress("ff04::1234");
     private static final Inet6Address GROUP_ADDR_SCOPE_3 =
-            (Inet6Address) InetAddresses.parseNumericAddress("ff03::1234");
+            (Inet6Address) parseNumericAddress("ff03::1234");
+    private static final Inet4Address IPV4_SERVER_ADDR =
+            (Inet4Address) parseNumericAddress("8.8.8.8");
+    private static final String NAT64_CIDR = "192.168.255.0/24";
 
     @Rule public final ThreadFeatureCheckerRule mThreadRule = new ThreadFeatureCheckerRule();
 
@@ -165,7 +173,7 @@
         mInfraDevice.sendEchoRequest(ftd.getOmrAddress());
 
         // Infra device receives an echo reply sent by FTD.
-        assertNotNull(pollForPacketOnInfraNetwork(ICMPV6_ECHO_REPLY_TYPE, ftd.getOmrAddress()));
+        assertNotNull(pollForIcmpPacketOnInfraNetwork(ICMPV6_ECHO_REPLY_TYPE, ftd.getOmrAddress()));
     }
 
     @Test
@@ -186,7 +194,7 @@
 
         mInfraDevice.sendEchoRequest(ftd.getOmrAddress());
 
-        assertNotNull(pollForPacketOnInfraNetwork(ICMPV6_ECHO_REPLY_TYPE, ftd.getOmrAddress()));
+        assertNotNull(pollForIcmpPacketOnInfraNetwork(ICMPV6_ECHO_REPLY_TYPE, ftd.getOmrAddress()));
     }
 
     @Test
@@ -213,7 +221,7 @@
 
             mInfraDevice.sendEchoRequest(ftd.getOmrAddress());
 
-            assertNotNull(pollForPacketOnInfraNetwork(ICMPV6_ECHO_REPLY_TYPE, ftdOmr));
+            assertNotNull(pollForIcmpPacketOnInfraNetwork(ICMPV6_ECHO_REPLY_TYPE, ftdOmr));
         } finally {
             runAsShell(MANAGE_TEST_NETWORKS, () -> oldInfraNetworkTracker.teardown());
         }
@@ -322,7 +330,7 @@
 
         mInfraDevice.sendEchoRequest(GROUP_ADDR_SCOPE_5);
 
-        assertNotNull(pollForPacketOnInfraNetwork(ICMPV6_ECHO_REPLY_TYPE, ftd.getOmrAddress()));
+        assertNotNull(pollForIcmpPacketOnInfraNetwork(ICMPV6_ECHO_REPLY_TYPE, ftd.getOmrAddress()));
     }
 
     @Test
@@ -354,7 +362,7 @@
 
         mInfraDevice.sendEchoRequest(GROUP_ADDR_SCOPE_3);
 
-        assertNull(pollForPacketOnInfraNetwork(ICMPV6_ECHO_REPLY_TYPE, ftd.getOmrAddress()));
+        assertNull(pollForIcmpPacketOnInfraNetwork(ICMPV6_ECHO_REPLY_TYPE, ftd.getOmrAddress()));
     }
 
     @Test
@@ -375,7 +383,7 @@
 
         mInfraDevice.sendEchoRequest(GROUP_ADDR_SCOPE_4);
 
-        assertNull(pollForPacketOnInfraNetwork(ICMPV6_ECHO_REPLY_TYPE, ftd.getOmrAddress()));
+        assertNull(pollForIcmpPacketOnInfraNetwork(ICMPV6_ECHO_REPLY_TYPE, ftd.getOmrAddress()));
     }
 
     @Test
@@ -405,13 +413,15 @@
 
         mInfraDevice.sendEchoRequest(GROUP_ADDR_SCOPE_5);
 
-        assertNotNull(pollForPacketOnInfraNetwork(ICMPV6_ECHO_REPLY_TYPE, ftd1.getOmrAddress()));
+        assertNotNull(
+                pollForIcmpPacketOnInfraNetwork(ICMPV6_ECHO_REPLY_TYPE, ftd1.getOmrAddress()));
 
         // Verify ping reply from ftd1 and ftd2 separately as the order of replies can't be
         // predicted.
         mInfraDevice.sendEchoRequest(GROUP_ADDR_SCOPE_4);
 
-        assertNotNull(pollForPacketOnInfraNetwork(ICMPV6_ECHO_REPLY_TYPE, ftd2.getOmrAddress()));
+        assertNotNull(
+                pollForIcmpPacketOnInfraNetwork(ICMPV6_ECHO_REPLY_TYPE, ftd2.getOmrAddress()));
     }
 
     @Test
@@ -441,12 +451,14 @@
 
         mInfraDevice.sendEchoRequest(GROUP_ADDR_SCOPE_5);
 
-        assertNotNull(pollForPacketOnInfraNetwork(ICMPV6_ECHO_REPLY_TYPE, ftd1.getOmrAddress()));
+        assertNotNull(
+                pollForIcmpPacketOnInfraNetwork(ICMPV6_ECHO_REPLY_TYPE, ftd1.getOmrAddress()));
 
         // Send the request twice as the order of replies from ftd1 and ftd2 are not guaranteed
         mInfraDevice.sendEchoRequest(GROUP_ADDR_SCOPE_5);
 
-        assertNotNull(pollForPacketOnInfraNetwork(ICMPV6_ECHO_REPLY_TYPE, ftd2.getOmrAddress()));
+        assertNotNull(
+                pollForIcmpPacketOnInfraNetwork(ICMPV6_ECHO_REPLY_TYPE, ftd2.getOmrAddress()));
     }
 
     @Test
@@ -469,9 +481,11 @@
         ftd.ping(GROUP_ADDR_SCOPE_4);
 
         assertNotNull(
-                pollForPacketOnInfraNetwork(ICMPV6_ECHO_REQUEST_TYPE, ftdOmr, GROUP_ADDR_SCOPE_5));
+                pollForIcmpPacketOnInfraNetwork(
+                        ICMPV6_ECHO_REQUEST_TYPE, ftdOmr, GROUP_ADDR_SCOPE_5));
         assertNotNull(
-                pollForPacketOnInfraNetwork(ICMPV6_ECHO_REQUEST_TYPE, ftdOmr, GROUP_ADDR_SCOPE_4));
+                pollForIcmpPacketOnInfraNetwork(
+                        ICMPV6_ECHO_REQUEST_TYPE, ftdOmr, GROUP_ADDR_SCOPE_4));
     }
 
     @Test
@@ -493,7 +507,7 @@
         ftd.ping(GROUP_ADDR_SCOPE_3);
 
         assertNull(
-                pollForPacketOnInfraNetwork(
+                pollForIcmpPacketOnInfraNetwork(
                         ICMPV6_ECHO_REQUEST_TYPE, ftd.getOmrAddress(), GROUP_ADDR_SCOPE_3));
     }
 
@@ -517,7 +531,8 @@
         ftd.ping(GROUP_ADDR_SCOPE_4, ftdLla);
 
         assertNull(
-                pollForPacketOnInfraNetwork(ICMPV6_ECHO_REQUEST_TYPE, ftdLla, GROUP_ADDR_SCOPE_4));
+                pollForIcmpPacketOnInfraNetwork(
+                        ICMPV6_ECHO_REQUEST_TYPE, ftdLla, GROUP_ADDR_SCOPE_4));
     }
 
     @Test
@@ -541,7 +556,7 @@
             ftd.ping(GROUP_ADDR_SCOPE_4, ftdMla);
 
             assertNull(
-                    pollForPacketOnInfraNetwork(
+                    pollForIcmpPacketOnInfraNetwork(
                             ICMPV6_ECHO_REQUEST_TYPE, ftdMla, GROUP_ADDR_SCOPE_4));
         }
     }
@@ -572,7 +587,7 @@
 
         mInfraDevice.sendEchoRequest(GROUP_ADDR_SCOPE_5);
 
-        assertNotNull(pollForPacketOnInfraNetwork(ICMPV6_ECHO_REPLY_TYPE, ftdOmr));
+        assertNotNull(pollForIcmpPacketOnInfraNetwork(ICMPV6_ECHO_REPLY_TYPE, ftdOmr));
     }
 
     @Test
@@ -600,18 +615,40 @@
         ftd.ping(GROUP_ADDR_SCOPE_4);
 
         assertNotNull(
-                pollForPacketOnInfraNetwork(ICMPV6_ECHO_REQUEST_TYPE, ftdOmr, GROUP_ADDR_SCOPE_4));
+                pollForIcmpPacketOnInfraNetwork(
+                        ICMPV6_ECHO_REQUEST_TYPE, ftdOmr, GROUP_ADDR_SCOPE_4));
+    }
+
+    @Test
+    public void nat64_threadDevicePingIpv4InfraDevice_outboundPacketIsForwarded() throws Exception {
+        FullThreadDevice ftd = mFtds.get(0);
+        joinNetworkAndWaitForOmr(ftd, DEFAULT_DATASET);
+        // TODO: enable NAT64 via ThreadNetworkController API instead of ot-ctl
+        mOtCtl.setNat64Cidr(NAT64_CIDR);
+        mOtCtl.setNat64Enabled(true);
+        waitFor(() -> mOtCtl.hasNat64PrefixInNetdata(), Duration.ofSeconds(10));
+
+        ftd.ping(IPV4_SERVER_ADDR);
+
+        assertNotNull(pollForIcmpPacketOnInfraNetwork(ICMP_ECHO, null, IPV4_SERVER_ADDR));
     }
 
     private void setUpInfraNetwork() throws Exception {
+        LinkProperties lp = new LinkProperties();
+        // NAT64 feature requires the infra network to have an IPv4 default route.
+        lp.addRoute(
+                new RouteInfo(
+                        new IpPrefix("0.0.0.0/0") /* destination */,
+                        null /* gateway */,
+                        null /* iface */,
+                        RouteInfo.RTN_UNICAST,
+                        1500 /* mtu */));
         mInfraNetworkTracker =
                 runAsShell(
                         MANAGE_TEST_NETWORKS,
-                        () ->
-                                initTestNetwork(
-                                        mContext, new LinkProperties(), 5000 /* timeoutMs */));
-        mController.setTestNetworkAsUpstreamAndWait(
-                mInfraNetworkTracker.getTestIface().getInterfaceName());
+                        () -> initTestNetwork(mContext, lp, 5000 /* timeoutMs */));
+        String infraNetworkName = mInfraNetworkTracker.getTestIface().getInterfaceName();
+        mController.setTestNetworkAsUpstreamAndWait(infraNetworkName);
     }
 
     private void tearDownInfraNetwork() {
@@ -648,20 +685,28 @@
         assertInfraLinkMemberOfGroup(address);
     }
 
-    private byte[] pollForPacketOnInfraNetwork(int type, Inet6Address srcAddress) {
-        return pollForPacketOnInfraNetwork(type, srcAddress, null);
+    private byte[] pollForIcmpPacketOnInfraNetwork(int type, InetAddress srcAddress) {
+        return pollForIcmpPacketOnInfraNetwork(type, srcAddress, null /* destAddress */);
     }
 
-    private byte[] pollForPacketOnInfraNetwork(
-            int type, Inet6Address srcAddress, Inet6Address destAddress) {
-        Predicate<byte[]> filter;
-        filter =
+    private byte[] pollForIcmpPacketOnInfraNetwork(
+            int type, InetAddress srcAddress, InetAddress destAddress) {
+        if (srcAddress == null && destAddress == null) {
+            throw new IllegalArgumentException("srcAddress and destAddress cannot be both null");
+        }
+        if (srcAddress != null && destAddress != null) {
+            if ((srcAddress instanceof Inet4Address) != (destAddress instanceof Inet4Address)) {
+                throw new IllegalArgumentException(
+                        "srcAddress and destAddress must be both IPv4 or both IPv6");
+            }
+        }
+        boolean isIpv4 =
+                (srcAddress instanceof Inet4Address) || (destAddress instanceof Inet4Address);
+        final Predicate<byte[]> filter =
                 p ->
-                        (isExpectedIcmpv6Packet(p, type)
-                                && (srcAddress == null ? true : isFromIpv6Source(p, srcAddress))
-                                && (destAddress == null
-                                        ? true
-                                        : isToIpv6Destination(p, destAddress)));
+                        (isIpv4 ? isExpectedIcmpv4Packet(p, type) : isExpectedIcmpv6Packet(p, type))
+                                && (srcAddress == null || isFrom(p, srcAddress))
+                                && (destAddress == null || isTo(p, destAddress));
         return pollForPacket(mInfraNetworkReader, filter);
     }
 }
diff --git a/thread/tests/integration/src/android/net/thread/utils/FullThreadDevice.java b/thread/tests/integration/src/android/net/thread/utils/FullThreadDevice.java
index 8440bbc..083a841 100644
--- a/thread/tests/integration/src/android/net/thread/utils/FullThreadDevice.java
+++ b/thread/tests/integration/src/android/net/thread/utils/FullThreadDevice.java
@@ -417,7 +417,7 @@
         executeCommand("ipmaddr add " + address.getHostAddress());
     }
 
-    public void ping(Inet6Address address, Inet6Address source) {
+    public void ping(InetAddress address, Inet6Address source) {
         ping(
                 address,
                 source,
@@ -428,7 +428,7 @@
                 PING_TIMEOUT_0_1_SECOND);
     }
 
-    public void ping(Inet6Address address) {
+    public void ping(InetAddress address) {
         ping(
                 address,
                 null,
@@ -440,7 +440,7 @@
     }
 
     /** Returns the number of ping reply packets received. */
-    public int ping(Inet6Address address, int count) {
+    public int ping(InetAddress address, int count) {
         List<String> output =
                 ping(
                         address,
@@ -454,7 +454,7 @@
     }
 
     private List<String> ping(
-            Inet6Address address,
+            InetAddress address,
             Inet6Address source,
             int size,
             int count,
diff --git a/thread/tests/integration/src/android/net/thread/utils/IntegrationTestUtils.java b/thread/tests/integration/src/android/net/thread/utils/IntegrationTestUtils.java
index 7b0c415..82e9332 100644
--- a/thread/tests/integration/src/android/net/thread/utils/IntegrationTestUtils.java
+++ b/thread/tests/integration/src/android/net/thread/utils/IntegrationTestUtils.java
@@ -16,6 +16,7 @@
 package android.net.thread.utils;
 
 import static android.net.NetworkCapabilities.NET_CAPABILITY_LOCAL_NETWORK;
+import static android.system.OsConstants.IPPROTO_ICMP;
 import static android.system.OsConstants.IPPROTO_ICMPV6;
 
 import static com.android.compatibility.common.util.SystemUtil.runShellCommandOrThrow;
@@ -49,7 +50,9 @@
 import androidx.test.core.app.ApplicationProvider;
 
 import com.android.net.module.util.Struct;
+import com.android.net.module.util.structs.Icmpv4Header;
 import com.android.net.module.util.structs.Icmpv6Header;
+import com.android.net.module.util.structs.Ipv4Header;
 import com.android.net.module.util.structs.Ipv6Header;
 import com.android.net.module.util.structs.PrefixInformationOption;
 import com.android.net.module.util.structs.RaHeader;
@@ -62,6 +65,7 @@
 import java.io.IOException;
 import java.net.DatagramPacket;
 import java.net.DatagramSocket;
+import java.net.Inet4Address;
 import java.net.Inet6Address;
 import java.net.InetAddress;
 import java.net.InetSocketAddress;
@@ -192,16 +196,36 @@
         return null;
     }
 
-    /** Returns {@code true} if {@code packet} is an ICMPv6 packet of given {@code type}. */
-    public static boolean isExpectedIcmpv6Packet(byte[] packet, int type) {
-        if (packet == null) {
+    /** Returns {@code true} if {@code packet} is an ICMPv4 packet of given {@code type}. */
+    public static boolean isExpectedIcmpv4Packet(byte[] packet, int type) {
+        ByteBuffer buf = makeByteBuffer(packet);
+        Ipv4Header header = extractIpv4Header(buf); // IPv4 options are not skipped
+        if (header == null) {
             return false;
         }
-        ByteBuffer buf = ByteBuffer.wrap(packet);
+        if (header.protocol != (byte) IPPROTO_ICMP) {
+            return false;
+        }
         try {
-            if (Struct.parse(Ipv6Header.class, buf).nextHeader != (byte) IPPROTO_ICMPV6) {
-                return false;
-            }
+            return Struct.parse(Icmpv4Header.class, buf).type == (short) type;
+        } catch (IllegalArgumentException ignored) {
+            // It's fine that the passed-in packet is malformed because it could be sent
+            // by anybody.
+        }
+        return false;
+    }
+
+    /** Returns {@code true} if {@code packet} is an ICMPv6 packet of given {@code type}. */
+    public static boolean isExpectedIcmpv6Packet(byte[] packet, int type) {
+        ByteBuffer buf = makeByteBuffer(packet);
+        Ipv6Header header = extractIpv6Header(buf);
+        if (header == null) {
+            return false;
+        }
+        if (header.nextHeader != (byte) IPPROTO_ICMPV6) {
+            return false;
+        }
+        try {
             return Struct.parse(Icmpv6Header.class, buf).type == (short) type;
         } catch (IllegalArgumentException ignored) {
             // It's fine that the passed in packet is malformed because it's could be sent
@@ -210,32 +234,66 @@
         return false;
     }
 
-    public static boolean isFromIpv6Source(byte[] packet, Inet6Address src) {
-        if (packet == null) {
-            return false;
-        }
-        ByteBuffer buf = ByteBuffer.wrap(packet);
-        try {
-            return Struct.parse(Ipv6Header.class, buf).srcIp.equals(src);
-        } catch (IllegalArgumentException ignored) {
-            // It's fine that the passed in packet is malformed because it's could be sent
-            // by anybody.
+    public static boolean isFrom(byte[] packet, InetAddress src) {
+        if (src instanceof Inet4Address) {
+            return isFromIpv4Source(packet, (Inet4Address) src);
+        } else if (src instanceof Inet6Address) {
+            return isFromIpv6Source(packet, (Inet6Address) src);
         }
         return false;
     }
 
-    public static boolean isToIpv6Destination(byte[] packet, Inet6Address dest) {
-        if (packet == null) {
-            return false;
+    public static boolean isTo(byte[] packet, InetAddress dest) {
+        if (dest instanceof Inet4Address) {
+            return isToIpv4Destination(packet, (Inet4Address) dest);
+        } else if (dest instanceof Inet6Address) {
+            return isToIpv6Destination(packet, (Inet6Address) dest);
         }
-        ByteBuffer buf = ByteBuffer.wrap(packet);
+        return false;
+    }
+
+    private static boolean isFromIpv4Source(byte[] packet, Inet4Address src) {
+        Ipv4Header header = extractIpv4Header(makeByteBuffer(packet));
+        return header != null && header.srcIp.equals(src);
+    }
+
+    private static boolean isFromIpv6Source(byte[] packet, Inet6Address src) {
+        Ipv6Header header = extractIpv6Header(makeByteBuffer(packet));
+        return header != null && header.srcIp.equals(src);
+    }
+
+    private static boolean isToIpv4Destination(byte[] packet, Inet4Address dest) {
+        Ipv4Header header = extractIpv4Header(makeByteBuffer(packet));
+        return header != null && header.dstIp.equals(dest);
+    }
+
+    private static boolean isToIpv6Destination(byte[] packet, Inet6Address dest) {
+        Ipv6Header header = extractIpv6Header(makeByteBuffer(packet));
+        return header != null && header.dstIp.equals(dest);
+    }
+
+    private static ByteBuffer makeByteBuffer(byte[] packet) {
+        return packet == null ? null : ByteBuffer.wrap(packet);
+    }
+
+    private static Ipv4Header extractIpv4Header(ByteBuffer buf) {
         try {
-            return Struct.parse(Ipv6Header.class, buf).dstIp.equals(dest);
+            return buf == null ? null : Struct.parse(Ipv4Header.class, buf);
         } catch (IllegalArgumentException ignored) {
             // It's fine that the passed in packet is malformed because it's could be sent
             // by anybody.
         }
-        return false;
+        return null; // unparsable packet
+    }
+
+    private static Ipv6Header extractIpv6Header(ByteBuffer buf) {
+        try {
+            return buf == null ? null : Struct.parse(Ipv6Header.class, buf);
+        } catch (IllegalArgumentException ignored) {
+            // It's fine that the passed-in packet is malformed because it could be sent
+            // by anybody.
+        }
+        return null; // unparsable packet
     }
 
     /** Returns the Prefix Information Options (PIO) extracted from an ICMPv6 RA message. */
diff --git a/thread/tests/integration/src/android/net/thread/utils/OtDaemonController.java b/thread/tests/integration/src/android/net/thread/utils/OtDaemonController.java
index b3175fd..15a3f5c 100644
--- a/thread/tests/integration/src/android/net/thread/utils/OtDaemonController.java
+++ b/thread/tests/integration/src/android/net/thread/utils/OtDaemonController.java
@@ -105,6 +105,34 @@
         return prefixes.isEmpty() ? null : prefixes.get(0);
     }
 
+    /**
+     * Enables or disables the NAT64 feature by running {@code nat64 enable} or
+     * {@code nat64 disable} through {@link #executeCommand}.
+     */
+    public void setNat64Enabled(boolean enabled) {
+        final String action = enabled ? "enable" : "disable";
+        executeCommand("nat64 " + action);
+    }
+
+    /**
+     * Sets the NAT64 CIDR by running {@code nat64 cidr <cidr>} through
+     * {@link #executeCommand}.
+     */
+    public void setNat64Cidr(String cidr) {
+        final String command = "nat64 cidr " + cidr;
+        executeCommand(command);
+    }
+
+    /**
+     * Returns whether the {@code netdata show} output contains a NAT64 prefix. A line containing
+     * {@code " sn"} is taken as the NAT64 route entry, for example (in the 'Routes' section):
+     * {@code fdb2:bae3:5b59:2:0:0::/96 sn low c000}
+     */
+    public boolean hasNat64PrefixInNetdata() {
+        return executeCommandAndParse("netdata show").stream()
+                .anyMatch(line -> line.contains(" sn"));
+    }
+
     public String executeCommand(String cmd) {
         return SystemUtil.runShellCommand(OT_CTL + " " + cmd);
     }
diff --git a/thread/tests/unit/src/com/android/server/thread/ThreadNetworkControllerServiceTest.java b/thread/tests/unit/src/com/android/server/thread/ThreadNetworkControllerServiceTest.java
index df1a65b..be32764 100644
--- a/thread/tests/unit/src/com/android/server/thread/ThreadNetworkControllerServiceTest.java
+++ b/thread/tests/unit/src/com/android/server/thread/ThreadNetworkControllerServiceTest.java
@@ -17,6 +17,11 @@
 package com.android.server.thread;
 
 import static android.Manifest.permission.NETWORK_SETTINGS;
+import static android.net.NetworkCapabilities.NET_CAPABILITY_INTERNET;
+import static android.net.NetworkCapabilities.NET_CAPABILITY_NOT_VPN;
+import static android.net.NetworkCapabilities.TRANSPORT_ETHERNET;
+import static android.net.NetworkCapabilities.TRANSPORT_THREAD;
+import static android.net.NetworkCapabilities.TRANSPORT_WIFI;
 import static android.net.thread.ActiveOperationalDataset.CHANNEL_PAGE_24_GHZ;
 import static android.net.thread.ThreadNetworkController.STATE_DISABLED;
 import static android.net.thread.ThreadNetworkController.STATE_ENABLED;
@@ -39,6 +44,7 @@
 import static org.mockito.ArgumentMatchers.argThat;
 import static org.mockito.ArgumentMatchers.eq;
 import static org.mockito.Mockito.any;
+import static org.mockito.Mockito.atLeastOnce;
 import static org.mockito.Mockito.clearInvocations;
 import static org.mockito.Mockito.doAnswer;
 import static org.mockito.Mockito.doNothing;
@@ -58,7 +64,6 @@
 import android.content.res.Resources;
 import android.net.ConnectivityManager;
 import android.net.NetworkAgent;
-import android.net.NetworkCapabilities;
 import android.net.NetworkProvider;
 import android.net.NetworkRequest;
 import android.net.thread.ActiveOperationalDataset;
@@ -746,6 +751,30 @@
     }
 
     @Test
+    public void initialize_upstreamNetworkRequestHasCertainTransportTypesAndCapabilities() {
+        mService.initialize();
+        mTestLooper.dispatchAll();
+        // Every registered request that isn't for TRANSPORT_THREAD is an upstream request.
+        ArgumentCaptor<NetworkRequest> networkRequestCaptor =
+                ArgumentCaptor.forClass(NetworkRequest.class);
+        verify(mMockConnectivityManager, atLeastOnce())
+                .registerNetworkCallback(
+                        networkRequestCaptor.capture(),
+                        any(ConnectivityManager.NetworkCallback.class),
+                        any(Handler.class));
+        List<NetworkRequest> upstreamNetworkRequests =
+                networkRequestCaptor.getAllValues().stream()
+                        .filter(nr -> !nr.hasTransport(TRANSPORT_THREAD))
+                        .toList();
+        assertThat(upstreamNetworkRequests).hasSize(1);
+        NetworkRequest upstreamNetworkRequest = upstreamNetworkRequests.get(0);
+        assertThat(upstreamNetworkRequest.hasTransport(TRANSPORT_WIFI)).isTrue();
+        assertThat(upstreamNetworkRequest.hasTransport(TRANSPORT_ETHERNET)).isTrue();
+        assertThat(upstreamNetworkRequest.hasCapability(NET_CAPABILITY_NOT_VPN)).isTrue();
+        assertThat(upstreamNetworkRequest.hasCapability(NET_CAPABILITY_INTERNET)).isTrue();
+    }
+
+    @Test
     public void setTestNetworkAsUpstream_upstreamNetworkRequestAlwaysDisallowsVpn() {
         mService.initialize();
         mTestLooper.dispatchAll();
@@ -768,10 +797,8 @@
         NetworkRequest networkRequest1 = networkRequestCaptor.getAllValues().get(0);
         NetworkRequest networkRequest2 = networkRequestCaptor.getAllValues().get(1);
         assertThat(networkRequest1.getNetworkSpecifier()).isNotNull();
-        assertThat(networkRequest1.hasCapability(NetworkCapabilities.NET_CAPABILITY_NOT_VPN))
-                .isTrue();
+        assertThat(networkRequest1.hasCapability(NET_CAPABILITY_NOT_VPN)).isTrue();
         assertThat(networkRequest2.getNetworkSpecifier()).isNull();
-        assertThat(networkRequest2.hasCapability(NetworkCapabilities.NET_CAPABILITY_NOT_VPN))
-                .isTrue();
+        assertThat(networkRequest2.hasCapability(NET_CAPABILITY_NOT_VPN)).isTrue();
     }
 }
