DscpPolicy BPF IPv4 Checksum Offset, DSCP Value Storage

Use the correct offset for the eth header when calculating the
checksum.

Store the DSCP value in the BPF map for subsequent skbs, instead
of the TOS, so that the calculation and checksum are correct.

Bug: 234808633

Change-Id: Ib40d4575455f34a8970eca8751b590319e2ee1ad
diff --git a/bpf_progs/dscp_policy.c b/bpf_progs/dscp_policy.c
index 38e1050..92ea0e2 100644
--- a/bpf_progs/dscp_policy.c
+++ b/bpf_progs/dscp_policy.c
@@ -14,17 +14,17 @@
  * limitations under the License.
  */
 
-#include <linux/types.h>
 #include <linux/bpf.h>
+#include <linux/if_ether.h>
 #include <linux/if_packet.h>
 #include <linux/ip.h>
 #include <linux/ipv6.h>
-#include <linux/if_ether.h>
 #include <linux/pkt_cls.h>
 #include <linux/tcp.h>
-#include <stdint.h>
+#include <linux/types.h>
 #include <netinet/in.h>
 #include <netinet/udp.h>
+#include <stdint.h>
 #include <string.h>
 
 // The resulting .o needs to load on the Android T beta 3 bpfloader
@@ -33,21 +33,25 @@
 #include "bpf_helpers.h"
 #include "dscp_policy.h"
 
+#define ECN_MASK 3  // ECN bits of the IPv4 TOS byte; UPDATE_TOS carries them over unchanged.
+#define IP4_OFFSET(field, header) ((header) + offsetof(struct iphdr, field))
+#define UPDATE_TOS(dscp, tos) (((dscp) << 2) | ((tos) & ECN_MASK))
+#define UPDATE_PRIORITY(dscp) (((dscp) >> 2) + 0x60)
+#define UPDATE_FLOW_LABEL(dscp, flow_lbl) ((((dscp) & 0xf) << 6) + ((flow_lbl) >> 6))
+
 DEFINE_BPF_MAP_GRW(switch_comp_map, ARRAY, int, uint64_t, 1, AID_SYSTEM)
 
 DEFINE_BPF_MAP_GRW(ipv4_socket_to_policies_map_A, HASH, uint64_t, RuleEntry, MAX_POLICIES,
-        AID_SYSTEM)
+                   AID_SYSTEM)
 DEFINE_BPF_MAP_GRW(ipv4_socket_to_policies_map_B, HASH, uint64_t, RuleEntry, MAX_POLICIES,
-        AID_SYSTEM)
+                   AID_SYSTEM)
 DEFINE_BPF_MAP_GRW(ipv6_socket_to_policies_map_A, HASH, uint64_t, RuleEntry, MAX_POLICIES,
-        AID_SYSTEM)
+                   AID_SYSTEM)
 DEFINE_BPF_MAP_GRW(ipv6_socket_to_policies_map_B, HASH, uint64_t, RuleEntry, MAX_POLICIES,
-        AID_SYSTEM)
+                   AID_SYSTEM)
 
-DEFINE_BPF_MAP_GRW(ipv4_dscp_policies_map, ARRAY, uint32_t, DscpPolicy, MAX_POLICIES,
-        AID_SYSTEM)
-DEFINE_BPF_MAP_GRW(ipv6_dscp_policies_map, ARRAY, uint32_t, DscpPolicy, MAX_POLICIES,
-        AID_SYSTEM)
+DEFINE_BPF_MAP_GRW(ipv4_dscp_policies_map, ARRAY, uint32_t, DscpPolicy, MAX_POLICIES, AID_SYSTEM)
+DEFINE_BPF_MAP_GRW(ipv6_dscp_policies_map, ARRAY, uint32_t, DscpPolicy, MAX_POLICIES, AID_SYSTEM)
 
 static inline __always_inline void match_policy(struct __sk_buff* skb, bool ipv4, bool is_eth) {
     void* data = (void*)(long)skb->data;
@@ -69,21 +73,21 @@
 
     // used for map lookup
     uint64_t cookie = bpf_get_socket_cookie(skb);
-    if (!cookie)
-        return;
+    if (!cookie) return;
 
     uint16_t sport = 0;
     uint16_t dport = 0;
-    uint8_t protocol = 0; // TODO: Use are reserved value? Or int (-1) and cast to uint below?
+    uint8_t protocol = 0;  // TODO: Use a reserved value? Or int (-1) and cast to uint below?
     struct in6_addr srcIp = {};
     struct in6_addr dstIp = {};
-    uint8_t tos = 0; // Only used for IPv4
-    uint8_t priority = 0; // Only used for IPv6
-    uint8_t flow_lbl = 0; // Only used for IPv6
+    uint8_t tos = 0;       // Only used for IPv4
+    uint8_t priority = 0;  // Only used for IPv6
+    uint8_t flow_lbl = 0;  // Only used for IPv6
     if (ipv4) {
         const struct iphdr* const iph = is_eth ? (void*)(eth + 1) : data;
+        hdr_size = l2_header_size + sizeof(struct iphdr);
         // Must have ipv4 header
-        if (data + l2_header_size + sizeof(*iph) > data_end) return;
+        if (data + hdr_size > data_end) return;
 
         // IP version must be 4
         if (iph->version != 4) return;
@@ -100,11 +104,11 @@
         dstIp.s6_addr32[3] = iph->daddr;
         protocol = iph->protocol;
         tos = iph->tos;
-        hdr_size = sizeof(struct iphdr);
     } else {
         struct ipv6hdr* ip6h = is_eth ? (void*)(eth + 1) : data;
+        hdr_size = l2_header_size + sizeof(struct ipv6hdr);
         // Must have ipv6 header
-        if (data + l2_header_size + sizeof(*ip6h) > data_end) return;
+        if (data + hdr_size > data_end) return;
 
         if (ip6h->version != 6) return;
 
@@ -113,29 +117,24 @@
         protocol = ip6h->nexthdr;
         priority = ip6h->priority;
         flow_lbl = ip6h->flow_lbl[0];
-        hdr_size = sizeof(struct ipv6hdr);
     }
 
     switch (protocol) {
         case IPPROTO_UDP:
-        case IPPROTO_UDPLITE:
-        {
-            struct udphdr *udp;
+        case IPPROTO_UDPLITE: {
+            struct udphdr* udp;
             udp = data + hdr_size;
             if ((void*)(udp + 1) > data_end) return;
             sport = udp->source;
             dport = udp->dest;
-        }
-        break;
-        case IPPROTO_TCP:
-        {
-            struct tcphdr *tcp;
+        } break;
+        case IPPROTO_TCP: {
+            struct tcphdr* tcp;
             tcp = data + hdr_size;
             if ((void*)(tcp + 1) > data_end) return;
             sport = tcp->source;
             dport = tcp->dest;
-        }
-        break;
+        } break;
         default:
             return;
     }
@@ -156,22 +155,19 @@
     }
 
     if (existingRule && v6_equal(srcIp, existingRule->srcIp) &&
-                v6_equal(dstIp, existingRule->dstIp) &&
-                skb->ifindex == existingRule->ifindex &&
-                ntohs(sport) == htons(existingRule->srcPort) &&
-                ntohs(dport) == htons(existingRule->dstPort) &&
-                protocol == existingRule->proto) {
+        v6_equal(dstIp, existingRule->dstIp) && skb->ifindex == existingRule->ifindex &&
+        ntohs(sport) == htons(existingRule->srcPort) &&
+        ntohs(dport) == htons(existingRule->dstPort) && protocol == existingRule->proto) {
         if (ipv4) {
-            int ecn = tos & 3;
-            uint8_t newDscpVal = (existingRule->dscpVal << 2) + ecn;
-            int oldDscpVal = tos >> 2;
-            bpf_l3_csum_replace(skb, 1, oldDscpVal, newDscpVal, sizeof(uint8_t));
-            bpf_skb_store_bytes(skb, 1, &newDscpVal, sizeof(uint8_t), 0);
+            uint8_t newTos = UPDATE_TOS(existingRule->dscpVal, tos);
+            bpf_l3_csum_replace(skb, IP4_OFFSET(check, l2_header_size), htons(tos), htons(newTos),
+                                sizeof(uint16_t));
+            bpf_skb_store_bytes(skb, IP4_OFFSET(tos, l2_header_size), &newTos, sizeof(newTos), 0);
         } else {
-            uint8_t new_priority = (existingRule->dscpVal >> 2) + 0x60;
-            uint8_t new_flow_label = ((existingRule->dscpVal & 0xf) << 6) + (priority >> 6);
-            bpf_skb_store_bytes(skb, 0, &new_priority, sizeof(uint8_t), 0);
-            bpf_skb_store_bytes(skb, 1, &new_flow_label, sizeof(uint8_t), 0);
+            uint8_t new_priority = UPDATE_PRIORITY(existingRule->dscpVal);
+            uint8_t new_flow_label = UPDATE_FLOW_LABEL(existingRule->dscpVal, flow_lbl);
+            bpf_skb_store_bytes(skb, 0 + l2_header_size, &new_priority, sizeof(uint8_t), 0);
+            bpf_skb_store_bytes(skb, 1 + l2_header_size, &new_flow_label, sizeof(uint8_t), 0);
         }
         return;
     }
@@ -196,32 +192,31 @@
 
         // If the policy lookup failed, presentFields is 0, or iface index does not match
         // index on skb buff, then we can continue to next policy.
-        if (!policy || policy->presentFields == 0 || policy->ifindex != skb->ifindex)
-            continue;
+        if (!policy || policy->presentFields == 0 || policy->ifindex != skb->ifindex) continue;
 
         if ((policy->presentFields & SRC_IP_MASK_FLAG) == SRC_IP_MASK_FLAG &&
-                v6_equal(srcIp, policy->srcIp)) {
+            v6_equal(srcIp, policy->srcIp)) {
             score++;
             tempMask |= SRC_IP_MASK_FLAG;
         }
         if ((policy->presentFields & DST_IP_MASK_FLAG) == DST_IP_MASK_FLAG &&
-                v6_equal(dstIp, policy->dstIp)) {
+            v6_equal(dstIp, policy->dstIp)) {
             score++;
             tempMask |= DST_IP_MASK_FLAG;
         }
         if ((policy->presentFields & SRC_PORT_MASK_FLAG) == SRC_PORT_MASK_FLAG &&
-                ntohs(sport) == htons(policy->srcPort)) {
+            ntohs(sport) == htons(policy->srcPort)) {
             score++;
             tempMask |= SRC_PORT_MASK_FLAG;
         }
         if ((policy->presentFields & DST_PORT_MASK_FLAG) == DST_PORT_MASK_FLAG &&
-                ntohs(dport) >= htons(policy->dstPortStart) &&
-                ntohs(dport) <= htons(policy->dstPortEnd)) {
+            ntohs(dport) >= htons(policy->dstPortStart) &&
+            ntohs(dport) <= htons(policy->dstPortEnd)) {
             score++;
             tempMask |= DST_PORT_MASK_FLAG;
         }
         if ((policy->presentFields & PROTO_MASK_FLAG) == PROTO_MASK_FLAG &&
-                protocol == policy->proto) {
+            protocol == policy->proto) {
             score++;
             tempMask |= PROTO_MASK_FLAG;
         }
@@ -232,7 +227,8 @@
         }
     }
 
-    uint8_t new_tos= 0; // Can 0 be used as default forwarding value?
+    uint8_t new_tos = 0;  // Can 0 be used as default forwarding value?
+    uint8_t new_dscp = 0;
     uint8_t new_priority = 0;
     uint8_t new_flow_lbl = 0;
     if (bestScore > 0) {
@@ -244,20 +240,16 @@
         }
 
         if (policy) {
-            // TODO: if DSCP value is already set ignore?
+            new_dscp = policy->dscpVal;
             if (ipv4) {
-                int ecn = tos & 3;
-                new_tos = (policy->dscpVal << 2) + ecn;
+                new_tos = UPDATE_TOS(new_dscp, tos);
             } else {
-                new_priority = (policy->dscpVal >> 2) + 0x60;
-                new_flow_lbl = ((policy->dscpVal & 0xf) << 6) + (flow_lbl >> 6);
-
-                // Set IPv6 curDscp value to stored value and recalulate priority
-                // and flow label during next use.
-                new_tos = policy->dscpVal;
+                new_priority = UPDATE_PRIORITY(new_dscp);
+                new_flow_lbl = UPDATE_FLOW_LABEL(new_dscp, flow_lbl);
             }
         }
-    } else return;
+    } else
+        return;
 
     RuleEntry value = {
         .srcIp = srcIp,
@@ -266,10 +258,10 @@
         .srcPort = sport,
         .dstPort = dport,
         .proto = protocol,
-        .dscpVal = new_tos,
+        .dscpVal = new_dscp,
     };
 
-    //Update map with new policy.
+    // Update map with new policy.
     if (ipv4) {
         if (*selectedMap == MAP_A) {
             bpf_ipv4_socket_to_policies_map_A_update_elem(&cookie, &value, BPF_ANY);
@@ -286,12 +278,11 @@
 
     // Need to store bytes after updating map or program will not load.
     if (ipv4 && new_tos != (tos & 252)) {
-        int oldDscpVal = tos >> 2;
-        bpf_l3_csum_replace(skb, 1, oldDscpVal, new_tos, sizeof(uint8_t));
-        bpf_skb_store_bytes(skb, 1, &new_tos, sizeof(uint8_t), 0);
+        bpf_l3_csum_replace(skb, IP4_OFFSET(check, l2_header_size), htons(tos), htons(new_tos), 2);
+        bpf_skb_store_bytes(skb, IP4_OFFSET(tos, l2_header_size), &new_tos, sizeof(new_tos), 0);
     } else if (!ipv4 && (new_priority != priority || new_flow_lbl != flow_lbl)) {
-        bpf_skb_store_bytes(skb, 0, &new_priority, sizeof(uint8_t), 0);
-        bpf_skb_store_bytes(skb, 1, &new_flow_lbl, sizeof(uint8_t), 0);
+        bpf_skb_store_bytes(skb, l2_header_size, &new_priority, sizeof(new_priority), 0);
+        bpf_skb_store_bytes(skb, l2_header_size + 1, &new_flow_lbl, sizeof(new_flow_lbl), 0);
     }
     return;
 }
@@ -299,7 +290,6 @@
 DEFINE_BPF_PROG_KVER("schedcls/set_dscp_ether", AID_ROOT, AID_SYSTEM,
                      schedcls_set_dscp_ether, KVER(5, 15, 0))
 (struct __sk_buff* skb) {
-
     if (skb->pkt_type != PACKET_HOST) return TC_ACT_PIPE;
 
     if (skb->protocol == htons(ETH_P_IP)) {