Merge "call synchronizeKernelRCU() after writing false to disable"
diff --git a/TEST_MAPPING b/TEST_MAPPING
index 4c50bee..d2f6d6a 100644
--- a/TEST_MAPPING
+++ b/TEST_MAPPING
@@ -97,7 +97,13 @@
     },
     // Runs both NetHttpTests and CtsNetHttpTestCases
     {
-      "name": "NetHttpCoverageTests"
+      "name": "NetHttpCoverageTests",
+      "options": [
+        {
+          // These LargeTest tests sometimes take longer than 1 min, which is the presubmit timeout
+          "exclude-annotation": "androidx.test.filters.LargeTest"
+        }
+      ]
     }
   ],
   "postsubmit": [
@@ -211,7 +217,13 @@
       "name": "libnetworkstats_test[CaptivePortalLoginGoogle.apk+NetworkStackGoogle.apk+com.google.android.resolv.apex+com.google.android.tethering.apex]"
     },
     {
-      "name": "NetHttpCoverageTests[CaptivePortalLoginGoogle.apk+NetworkStackGoogle.apk+com.google.android.resolv.apex+com.google.android.tethering.apex]"
+      "name": "NetHttpCoverageTests[CaptivePortalLoginGoogle.apk+NetworkStackGoogle.apk+com.google.android.resolv.apex+com.google.android.tethering.apex]",
+      "options": [
+        {
+          // These LargeTest tests sometimes take longer than 1 min, which is the presubmit timeout
+          "exclude-annotation": "androidx.test.filters.LargeTest"
+        }
+      ]
     }
   ],
   "mainline-postsubmit": [
diff --git a/Tethering/tests/mts/src/android/tethering/mts/MtsEthernetTetheringTest.java b/Tethering/tests/mts/src/android/tethering/mts/MtsEthernetTetheringTest.java
index cb57d13..c2bc812 100644
--- a/Tethering/tests/mts/src/android/tethering/mts/MtsEthernetTetheringTest.java
+++ b/Tethering/tests/mts/src/android/tethering/mts/MtsEthernetTetheringTest.java
@@ -80,8 +80,8 @@
     // Per RX UDP packet size: iphdr (20) + udphdr (8) + payload (2) = 30 bytes.
     private static final int RX_UDP_PACKET_SIZE = 30;
     private static final int RX_UDP_PACKET_COUNT = 456;
-    // Per TX UDP packet size: ethhdr (14) + iphdr (20) + udphdr (8) + payload (2) = 44 bytes.
-    private static final int TX_UDP_PACKET_SIZE = 44;
+    // Per TX UDP packet size: iphdr (20) + udphdr (8) + payload (2) = 30 bytes.
+    private static final int TX_UDP_PACKET_SIZE = 30;
     private static final int TX_UDP_PACKET_COUNT = 123;
 
     private static final String DUMPSYS_TETHERING_RAWMAP_ARG = "bpfRawMap";
diff --git a/bpf_progs/offload.c b/bpf_progs/offload.c
index a8612df..56ace19 100644
--- a/bpf_progs/offload.c
+++ b/bpf_progs/offload.c
@@ -232,13 +232,13 @@
     // This would require a much newer kernel with newer ebpf accessors.
     // (This is also blindly assuming 12 bytes of tcp timestamp option in tcp header)
     uint64_t packets = 1;
-    uint64_t bytes = skb->len;
-    if (bytes > v->pmtu) {
-        const int tcp_overhead = sizeof(struct ipv6hdr) + sizeof(struct tcphdr) + 12;
-        const int mss = v->pmtu - tcp_overhead;
-        const uint64_t payload = bytes - tcp_overhead;
+    uint64_t L3_bytes = skb->len - l2_header_size;
+    if (L3_bytes > v->pmtu) {
+        const int tcp6_overhead = sizeof(struct ipv6hdr) + sizeof(struct tcphdr) + 12;
+        const int mss = v->pmtu - tcp6_overhead;
+        const uint64_t payload = L3_bytes - tcp6_overhead;
         packets = (payload + mss - 1) / mss;
-        bytes = tcp_overhead * packets + payload;
+        L3_bytes = tcp6_overhead * packets + payload;
     }
 
     // Are we past the limit?  If so, then abort...
@@ -247,7 +247,7 @@
     // a packet we let the core stack deal with things.
     // (The core stack needs to handle limits correctly anyway,
     // since we don't offload all traffic in both directions)
-    if (stat_v->rxBytes + stat_v->txBytes + bytes > *limit_v) TC_PUNT(LIMIT_REACHED);
+    if (stat_v->rxBytes + stat_v->txBytes + L3_bytes > *limit_v) TC_PUNT(LIMIT_REACHED);
 
     if (!is_ethernet) {
         // Try to inject an ethernet header, and simply return if we fail.
@@ -287,7 +287,7 @@
     bpf_csum_update(skb, 0xFFFF - ntohs(old_hl) + ntohs(new_hl));
 
     __sync_fetch_and_add(downstream ? &stat_v->rxPackets : &stat_v->txPackets, packets);
-    __sync_fetch_and_add(downstream ? &stat_v->rxBytes : &stat_v->txBytes, bytes);
+    __sync_fetch_and_add(downstream ? &stat_v->rxBytes : &stat_v->txBytes, L3_bytes);
 
     // Overwrite any mac header with the new one
     // For a rawip tx interface it will simply be a bunch of zeroes and later stripped.
@@ -449,13 +449,13 @@
     // This would require a much newer kernel with newer ebpf accessors.
     // (This is also blindly assuming 12 bytes of tcp timestamp option in tcp header)
     uint64_t packets = 1;
-    uint64_t bytes = skb->len;
-    if (bytes > v->pmtu) {
-        const int tcp_overhead = sizeof(struct iphdr) + sizeof(struct tcphdr) + 12;
-        const int mss = v->pmtu - tcp_overhead;
-        const uint64_t payload = bytes - tcp_overhead;
+    uint64_t L3_bytes = skb->len - l2_header_size;
+    if (L3_bytes > v->pmtu) {
+        const int tcp4_overhead = sizeof(struct iphdr) + sizeof(struct tcphdr) + 12;
+        const int mss = v->pmtu - tcp4_overhead;
+        const uint64_t payload = L3_bytes - tcp4_overhead;
         packets = (payload + mss - 1) / mss;
-        bytes = tcp_overhead * packets + payload;
+        L3_bytes = tcp4_overhead * packets + payload;
     }
 
     // Are we past the limit?  If so, then abort...
@@ -464,7 +464,7 @@
     // a packet we let the core stack deal with things.
     // (The core stack needs to handle limits correctly anyway,
     // since we don't offload all traffic in both directions)
-    if (stat_v->rxBytes + stat_v->txBytes + bytes > *limit_v) TC_PUNT(LIMIT_REACHED);
+    if (stat_v->rxBytes + stat_v->txBytes + L3_bytes > *limit_v) TC_PUNT(LIMIT_REACHED);
 
     if (!is_ethernet) {
         // Try to inject an ethernet header, and simply return if we fail.
@@ -540,7 +540,7 @@
     if (updatetime) v->last_used = bpf_ktime_get_boot_ns();
 
     __sync_fetch_and_add(downstream ? &stat_v->rxPackets : &stat_v->txPackets, packets);
-    __sync_fetch_and_add(downstream ? &stat_v->rxBytes : &stat_v->txBytes, bytes);
+    __sync_fetch_and_add(downstream ? &stat_v->rxBytes : &stat_v->txBytes, L3_bytes);
 
     // Redirect to forwarded interface.
     //
diff --git a/tests/cts/hostside/app/src/com/android/cts/net/hostside/VpnTest.java b/tests/cts/hostside/app/src/com/android/cts/net/hostside/VpnTest.java
index b6902b5..c28ee64 100755
--- a/tests/cts/hostside/app/src/com/android/cts/net/hostside/VpnTest.java
+++ b/tests/cts/hostside/app/src/com/android/cts/net/hostside/VpnTest.java
@@ -1274,6 +1274,31 @@
     }
 
     @Test
+    public void testSocketClosed() throws Exception {
+        // Starting the VPN must close the remote app's sockets but not this app's own socket.
+        assumeTrue(supportedHardware());
+
+        // Open sockets before the VPN starts: one from this app and 30 from the remote app.
+        final FileDescriptor localFd = openSocketFd(TEST_HOST, 80, TIMEOUT_MS);
+        final List<FileDescriptor> remoteFds = new ArrayList<>();
+        for (int i = 0; i < 30; i++) {
+            remoteFds.add(openSocketFdInOtherApp(TEST_HOST, 80, TIMEOUT_MS));
+        }
+
+        final String allowedApps = mRemoteSocketFactoryClient.getPackageName() + "," + mPackageName;
+        startVpn(new String[] {"192.0.2.2/32", "2001:db8:1:2::ffe/128"},
+                new String[] {"192.0.2.0/24", "2001:db8::/32"},
+                allowedApps, "", null, null /* underlyingNetworks */, false /* isAlwaysMetered */);
+
+        // The socket owned by the VPN uid (this test app) is not closed.
+        assertSocketStillOpen(localFd, TEST_HOST);
+        // Sockets opened by the remote app, which is not the VPN uid, are closed.
+        for (final FileDescriptor remoteFd: remoteFds) {
+            assertSocketClosed(remoteFd, TEST_HOST);
+        }
+    }
+
+    @Test
     public void testExcludedRoutes() throws Exception {
         assumeTrue(supportedHardware());
         assumeTrue(SdkLevel.isAtLeastT());
diff --git a/tests/cts/hostside/src/com/android/cts/net/HostsideVpnTests.java b/tests/cts/hostside/src/com/android/cts/net/HostsideVpnTests.java
index 603779d..3ca4775 100644
--- a/tests/cts/hostside/src/com/android/cts/net/HostsideVpnTests.java
+++ b/tests/cts/hostside/src/com/android/cts/net/HostsideVpnTests.java
@@ -51,6 +51,10 @@
         runDeviceTests(TEST_PKG, TEST_PKG + ".VpnTest", "testAppDisallowed");
     }
 
+    public void testSocketClosed() throws Exception { // Runs VpnTest#testSocketClosed in TEST_PKG.
+        runDeviceTests(TEST_PKG, TEST_PKG + ".VpnTest", "testSocketClosed");
+    }
+
     public void testGetConnectionOwnerUidSecurity() throws Exception {
         runDeviceTests(TEST_PKG, TEST_PKG + ".VpnTest", "testGetConnectionOwnerUidSecurity");
     }