ebpf offload - add support for tethering of ipv4 udp

Test: TreeHugger
Signed-off-by: Maciej Żenczykowski <maze@google.com>
Change-Id: I6229e33cb219e9acd14f5b880cfa2ea0a34442f4
diff --git a/Tethering/bpf_progs/bpf_tethering.h b/Tethering/bpf_progs/bpf_tethering.h
index 4e006db..6591e81 100644
--- a/Tethering/bpf_progs/bpf_tethering.h
+++ b/Tethering/bpf_progs/bpf_tethering.h
@@ -44,6 +44,7 @@
     ERR(SHORT_L4_HEADER)     \
     ERR(SHORT_TCP_HEADER)    \
     ERR(SHORT_UDP_HEADER)    \
+    ERR(UDP_CSUM_ZERO)       \
     ERR(TRUNCATED_IPV4)      \
     ERR(_MAX)
 
diff --git a/Tethering/bpf_progs/offload.c b/Tethering/bpf_progs/offload.c
index 30b1f49..ca1a35b 100644
--- a/Tethering/bpf_progs/offload.c
+++ b/Tethering/bpf_progs/offload.c
@@ -428,6 +428,36 @@
     } else { // UDP
         // Make sure we can get at the udp header
         if (data + l2_header_size + sizeof(*ip) + sizeof(*udph) > data_end) PUNT(SHORT_UDP_HEADER);
+
+        // Skip handling of CHECKSUM_COMPLETE packets with udp checksum zero due to need for
+        // additional updating of skb->csum (this could be fixed up manually with more effort).
+        //
+        // Note that the in-kernel implementation of 'int64_t bpf_csum_update(skb, u32 csum)' is:
+        //   if (skb->ip_summed == CHECKSUM_COMPLETE)
+        //     return (skb->csum = csum_add(skb->csum, csum));
+        //   else
+        //     return -ENOTSUPP;
+        //
+        // So this will punt any CHECKSUM_COMPLETE packet with a zero UDP checksum,
+        // and leave all other packets unaffected (since at most it adds zero to skb->csum).
+        //
+        // In practice this should almost never trigger because most nics do not generate
+        // CHECKSUM_COMPLETE packets on receive - especially so for nics/drivers on a phone.
+        //
+        // Additionally since we're forwarding, in most cases the value of the skb->csum field
+        // shouldn't matter (it's not used by physical nic egress).
+        //
+        // It only matters if we're ingressing through a CHECKSUM_COMPLETE capable nic
+        // and egressing through a virtual interface looping back to the kernel itself
+        // (ie. something like veth) where the CHECKSUM_COMPLETE/skb->csum can get reused
+        // on ingress.
+        //
+        // If we were in the kernel we'd probably simply call
+        //   void skb_checksum_complete_unset(struct sk_buff *skb) {
+        //     if (skb->ip_summed == CHECKSUM_COMPLETE) skb->ip_summed = CHECKSUM_NONE;
+        //   }
+        // here instead.  Perhaps there should be a bpf helper for that?
+        if (!udph->check && (bpf_csum_update(skb, 0) >= 0)) PUNT(UDP_CSUM_ZERO);
     }
 
     Tether4Key k = {
@@ -486,8 +516,6 @@
     // since we don't offload all traffic in both directions)
     if (stat_v->rxBytes + stat_v->txBytes + bytes > *limit_v) PUNT(LIMIT_REACHED);
 
-    if (!is_tcp) PUNT(NON_TCP);  // TEMP HACK: will remove once UDP is supported.
-
     if (!is_ethernet) {
         // Try to inject an ethernet header, and simply return if we fail.
         // We do this even if TX interface is RAWIP and thus does not need an ethernet header,
@@ -520,26 +548,33 @@
     // For a rawip tx interface it will simply be a bunch of zeroes and later stripped.
     *eth = v->macHeader;
 
+    const int l4_offs_csum = is_tcp ? ETH_IP4_TCP_OFFSET(check) : ETH_IP4_UDP_OFFSET(check);
     const int sz4 = sizeof(__be32);
+    // UDP 0 is special and stored as FFFF (this flag also causes a csum of 0 to be unmodified)
+    const int l4_flags = is_tcp ? 0 : BPF_F_MARK_MANGLED_0;
     const __be32 old_daddr = k.dst4.s_addr;
     const __be32 old_saddr = k.src4.s_addr;
     const __be32 new_daddr = v->dst46.s6_addr32[3];
     const __be32 new_saddr = v->src46.s6_addr32[3];
 
-    bpf_l4_csum_replace(skb, ETH_IP4_TCP_OFFSET(check), old_daddr, new_daddr, sz4 | BPF_F_PSEUDO_HDR);
+    bpf_l4_csum_replace(skb, l4_offs_csum, old_daddr, new_daddr, sz4 | BPF_F_PSEUDO_HDR | l4_flags);
     bpf_l3_csum_replace(skb, ETH_IP4_OFFSET(check), old_daddr, new_daddr, sz4);
     bpf_skb_store_bytes(skb, ETH_IP4_OFFSET(daddr), &new_daddr, sz4, 0);
 
-    bpf_l4_csum_replace(skb, ETH_IP4_TCP_OFFSET(check), old_saddr, new_saddr, sz4 | BPF_F_PSEUDO_HDR);
+    bpf_l4_csum_replace(skb, l4_offs_csum, old_saddr, new_saddr, sz4 | BPF_F_PSEUDO_HDR | l4_flags);
     bpf_l3_csum_replace(skb, ETH_IP4_OFFSET(check), old_saddr, new_saddr, sz4);
     bpf_skb_store_bytes(skb, ETH_IP4_OFFSET(saddr), &new_saddr, sz4, 0);
 
     const int sz2 = sizeof(__be16);
-    bpf_l4_csum_replace(skb, ETH_IP4_TCP_OFFSET(check), k.srcPort, v->srcPort, sz2);
-    bpf_skb_store_bytes(skb, ETH_IP4_TCP_OFFSET(source), &v->srcPort, sz2, 0);
+    // The offsets for TCP and UDP ports: source (u16 @ L4 offset 0) & dest (u16 @ L4 offset 2) are
+    // actually the same, so the compiler should just optimize them both down to a constant.
+    bpf_l4_csum_replace(skb, l4_offs_csum, k.srcPort, v->srcPort, sz2 | l4_flags);
+    bpf_skb_store_bytes(skb, is_tcp ? ETH_IP4_TCP_OFFSET(source) : ETH_IP4_UDP_OFFSET(source),
+                        &v->srcPort, sz2, 0);
 
-    bpf_l4_csum_replace(skb, ETH_IP4_TCP_OFFSET(check), k.dstPort, v->dstPort, sz2);
-    bpf_skb_store_bytes(skb, ETH_IP4_TCP_OFFSET(dest), &v->dstPort, sz2, 0);
+    bpf_l4_csum_replace(skb, l4_offs_csum, k.dstPort, v->dstPort, sz2 | l4_flags);
+    bpf_skb_store_bytes(skb, is_tcp ? ETH_IP4_TCP_OFFSET(dest) : ETH_IP4_UDP_OFFSET(dest),
+                        &v->dstPort, sz2, 0);
 
     // TEMP HACK: lack of TTL decrement