Merge "DscpPolicy BPF IPv4 Checksum Offset, DSCP Value Storage" am: 5711788ea8

Original change: https://android-review.googlesource.com/c/platform/packages/modules/Connectivity/+/2125880

Change-Id: Icac13d2b3e6defc331588bc8d69004416f43bffd
Signed-off-by: Automerger Merge Worker <android-build-automerger-merge-worker@system.gserviceaccount.com>
diff --git a/bpf_progs/dscp_policy.c b/bpf_progs/dscp_policy.c
index 38e1050..92ea0e2 100644
--- a/bpf_progs/dscp_policy.c
+++ b/bpf_progs/dscp_policy.c
@@ -14,17 +14,17 @@
  * limitations under the License.
  */
 
-#include <linux/types.h>
 #include <linux/bpf.h>
+#include <linux/if_ether.h>
 #include <linux/if_packet.h>
 #include <linux/ip.h>
 #include <linux/ipv6.h>
-#include <linux/if_ether.h>
 #include <linux/pkt_cls.h>
 #include <linux/tcp.h>
-#include <stdint.h>
+#include <linux/types.h>
 #include <netinet/in.h>
 #include <netinet/udp.h>
+#include <stdint.h>
 #include <string.h>
 
 // The resulting .o needs to load on the Android T beta 3 bpfloader
@@ -33,21 +33,25 @@
 #include "bpf_helpers.h"
 #include "dscp_policy.h"
 
+#define ECN_MASK 3  // Low two bits of the TOS / traffic-class octet hold ECN.
+#define IP4_OFFSET(field, header) ((header) + offsetof(struct iphdr, field))
+#define UPDATE_TOS(dscp, tos) (((dscp) << 2) | ((tos) & ECN_MASK))  // DSCP -> bits 7..2
+#define UPDATE_PRIORITY(dscp) (((dscp) >> 2) + 0x60)  // IPv6 byte 0: version 6 + DSCP bits 5..2
+#define UPDATE_FLOW_LABEL(dscp, flow_lbl) ((((dscp) & 3) << 6) | ((flow_lbl) & 0x3f))  // keeps ECN
+
 DEFINE_BPF_MAP_GRW(switch_comp_map, ARRAY, int, uint64_t, 1, AID_SYSTEM)
 
 DEFINE_BPF_MAP_GRW(ipv4_socket_to_policies_map_A, HASH, uint64_t, RuleEntry, MAX_POLICIES,
-        AID_SYSTEM)
+                   AID_SYSTEM)
 DEFINE_BPF_MAP_GRW(ipv4_socket_to_policies_map_B, HASH, uint64_t, RuleEntry, MAX_POLICIES,
-        AID_SYSTEM)
+                   AID_SYSTEM)
 DEFINE_BPF_MAP_GRW(ipv6_socket_to_policies_map_A, HASH, uint64_t, RuleEntry, MAX_POLICIES,
-        AID_SYSTEM)
+                   AID_SYSTEM)
 DEFINE_BPF_MAP_GRW(ipv6_socket_to_policies_map_B, HASH, uint64_t, RuleEntry, MAX_POLICIES,
-        AID_SYSTEM)
+                   AID_SYSTEM)
 
-DEFINE_BPF_MAP_GRW(ipv4_dscp_policies_map, ARRAY, uint32_t, DscpPolicy, MAX_POLICIES,
-        AID_SYSTEM)
-DEFINE_BPF_MAP_GRW(ipv6_dscp_policies_map, ARRAY, uint32_t, DscpPolicy, MAX_POLICIES,
-        AID_SYSTEM)
+DEFINE_BPF_MAP_GRW(ipv4_dscp_policies_map, ARRAY, uint32_t, DscpPolicy, MAX_POLICIES, AID_SYSTEM)
+DEFINE_BPF_MAP_GRW(ipv6_dscp_policies_map, ARRAY, uint32_t, DscpPolicy, MAX_POLICIES, AID_SYSTEM)
 
 static inline __always_inline void match_policy(struct __sk_buff* skb, bool ipv4, bool is_eth) {
     void* data = (void*)(long)skb->data;
@@ -69,21 +73,21 @@
 
     // used for map lookup
     uint64_t cookie = bpf_get_socket_cookie(skb);
-    if (!cookie)
-        return;
+    if (!cookie) return;
 
     uint16_t sport = 0;
     uint16_t dport = 0;
-    uint8_t protocol = 0; // TODO: Use are reserved value? Or int (-1) and cast to uint below?
+    uint8_t protocol = 0;  // TODO: Use a reserved value? Or int (-1) and cast to uint below?
     struct in6_addr srcIp = {};
     struct in6_addr dstIp = {};
-    uint8_t tos = 0; // Only used for IPv4
-    uint8_t priority = 0; // Only used for IPv6
-    uint8_t flow_lbl = 0; // Only used for IPv6
+    uint8_t tos = 0;       // Only used for IPv4
+    uint8_t priority = 0;  // Only used for IPv6
+    uint8_t flow_lbl = 0;  // Only used for IPv6
     if (ipv4) {
         const struct iphdr* const iph = is_eth ? (void*)(eth + 1) : data;
+        hdr_size = l2_header_size + sizeof(struct iphdr);
         // Must have ipv4 header
-        if (data + l2_header_size + sizeof(*iph) > data_end) return;
+        if (data + hdr_size > data_end) return;
 
         // IP version must be 4
         if (iph->version != 4) return;
@@ -100,11 +104,11 @@
         dstIp.s6_addr32[3] = iph->daddr;
         protocol = iph->protocol;
         tos = iph->tos;
-        hdr_size = sizeof(struct iphdr);
     } else {
         struct ipv6hdr* ip6h = is_eth ? (void*)(eth + 1) : data;
+        hdr_size = l2_header_size + sizeof(struct ipv6hdr);
         // Must have ipv6 header
-        if (data + l2_header_size + sizeof(*ip6h) > data_end) return;
+        if (data + hdr_size > data_end) return;
 
         if (ip6h->version != 6) return;
 
@@ -113,29 +117,24 @@
         protocol = ip6h->nexthdr;
         priority = ip6h->priority;
         flow_lbl = ip6h->flow_lbl[0];
-        hdr_size = sizeof(struct ipv6hdr);
     }
 
     switch (protocol) {
         case IPPROTO_UDP:
-        case IPPROTO_UDPLITE:
-        {
-            struct udphdr *udp;
+        case IPPROTO_UDPLITE: {
+            struct udphdr* udp;
             udp = data + hdr_size;
             if ((void*)(udp + 1) > data_end) return;
             sport = udp->source;
             dport = udp->dest;
-        }
-        break;
-        case IPPROTO_TCP:
-        {
-            struct tcphdr *tcp;
+        } break;
+        case IPPROTO_TCP: {
+            struct tcphdr* tcp;
             tcp = data + hdr_size;
             if ((void*)(tcp + 1) > data_end) return;
             sport = tcp->source;
             dport = tcp->dest;
-        }
-        break;
+        } break;
         default:
             return;
     }
@@ -156,22 +155,19 @@
     }
 
     if (existingRule && v6_equal(srcIp, existingRule->srcIp) &&
-                v6_equal(dstIp, existingRule->dstIp) &&
-                skb->ifindex == existingRule->ifindex &&
-                ntohs(sport) == htons(existingRule->srcPort) &&
-                ntohs(dport) == htons(existingRule->dstPort) &&
-                protocol == existingRule->proto) {
+        v6_equal(dstIp, existingRule->dstIp) && skb->ifindex == existingRule->ifindex &&
+        ntohs(sport) == htons(existingRule->srcPort) &&
+        ntohs(dport) == htons(existingRule->dstPort) && protocol == existingRule->proto) {
         if (ipv4) {
-            int ecn = tos & 3;
-            uint8_t newDscpVal = (existingRule->dscpVal << 2) + ecn;
-            int oldDscpVal = tos >> 2;
-            bpf_l3_csum_replace(skb, 1, oldDscpVal, newDscpVal, sizeof(uint8_t));
-            bpf_skb_store_bytes(skb, 1, &newDscpVal, sizeof(uint8_t), 0);
+            uint8_t newTos = UPDATE_TOS(existingRule->dscpVal, tos);
+            bpf_l3_csum_replace(skb, IP4_OFFSET(check, l2_header_size), htons(tos), htons(newTos),
+                                sizeof(uint16_t));
+            bpf_skb_store_bytes(skb, IP4_OFFSET(tos, l2_header_size), &newTos, sizeof(newTos), 0);
         } else {
-            uint8_t new_priority = (existingRule->dscpVal >> 2) + 0x60;
-            uint8_t new_flow_label = ((existingRule->dscpVal & 0xf) << 6) + (priority >> 6);
-            bpf_skb_store_bytes(skb, 0, &new_priority, sizeof(uint8_t), 0);
-            bpf_skb_store_bytes(skb, 1, &new_flow_label, sizeof(uint8_t), 0);
+            uint8_t new_priority = UPDATE_PRIORITY(existingRule->dscpVal);
+            uint8_t new_flow_label = UPDATE_FLOW_LABEL(existingRule->dscpVal, flow_lbl);
+            bpf_skb_store_bytes(skb, 0 + l2_header_size, &new_priority, sizeof(uint8_t), 0);
+            bpf_skb_store_bytes(skb, 1 + l2_header_size, &new_flow_label, sizeof(uint8_t), 0);
         }
         return;
     }
@@ -196,32 +192,31 @@
 
         // If the policy lookup failed, presentFields is 0, or iface index does not match
         // index on skb buff, then we can continue to next policy.
-        if (!policy || policy->presentFields == 0 || policy->ifindex != skb->ifindex)
-            continue;
+        if (!policy || policy->presentFields == 0 || policy->ifindex != skb->ifindex) continue;
 
         if ((policy->presentFields & SRC_IP_MASK_FLAG) == SRC_IP_MASK_FLAG &&
-                v6_equal(srcIp, policy->srcIp)) {
+            v6_equal(srcIp, policy->srcIp)) {
             score++;
             tempMask |= SRC_IP_MASK_FLAG;
         }
         if ((policy->presentFields & DST_IP_MASK_FLAG) == DST_IP_MASK_FLAG &&
-                v6_equal(dstIp, policy->dstIp)) {
+            v6_equal(dstIp, policy->dstIp)) {
             score++;
             tempMask |= DST_IP_MASK_FLAG;
         }
         if ((policy->presentFields & SRC_PORT_MASK_FLAG) == SRC_PORT_MASK_FLAG &&
-                ntohs(sport) == htons(policy->srcPort)) {
+            ntohs(sport) == htons(policy->srcPort)) {
             score++;
             tempMask |= SRC_PORT_MASK_FLAG;
         }
         if ((policy->presentFields & DST_PORT_MASK_FLAG) == DST_PORT_MASK_FLAG &&
-                ntohs(dport) >= htons(policy->dstPortStart) &&
-                ntohs(dport) <= htons(policy->dstPortEnd)) {
+            ntohs(dport) >= htons(policy->dstPortStart) &&
+            ntohs(dport) <= htons(policy->dstPortEnd)) {
             score++;
             tempMask |= DST_PORT_MASK_FLAG;
         }
         if ((policy->presentFields & PROTO_MASK_FLAG) == PROTO_MASK_FLAG &&
-                protocol == policy->proto) {
+            protocol == policy->proto) {
             score++;
             tempMask |= PROTO_MASK_FLAG;
         }
@@ -232,7 +227,8 @@
         }
     }
 
-    uint8_t new_tos= 0; // Can 0 be used as default forwarding value?
+    uint8_t new_tos = 0;  // Can 0 be used as default forwarding value?
+    uint8_t new_dscp = 0;
     uint8_t new_priority = 0;
     uint8_t new_flow_lbl = 0;
     if (bestScore > 0) {
@@ -244,20 +240,16 @@
         }
 
         if (policy) {
-            // TODO: if DSCP value is already set ignore?
+            new_dscp = policy->dscpVal;
             if (ipv4) {
-                int ecn = tos & 3;
-                new_tos = (policy->dscpVal << 2) + ecn;
+                new_tos = UPDATE_TOS(new_dscp, tos);
             } else {
-                new_priority = (policy->dscpVal >> 2) + 0x60;
-                new_flow_lbl = ((policy->dscpVal & 0xf) << 6) + (flow_lbl >> 6);
-
-                // Set IPv6 curDscp value to stored value and recalulate priority
-                // and flow label during next use.
-                new_tos = policy->dscpVal;
+                new_priority = UPDATE_PRIORITY(new_dscp);
+                new_flow_lbl = UPDATE_FLOW_LABEL(new_dscp, flow_lbl);
             }
         }
-    } else return;
+    } else
+        return;
 
     RuleEntry value = {
         .srcIp = srcIp,
@@ -266,10 +258,10 @@
         .srcPort = sport,
         .dstPort = dport,
         .proto = protocol,
-        .dscpVal = new_tos,
+        .dscpVal = new_dscp,
     };
 
-    //Update map with new policy.
+    // Update map with new policy.
     if (ipv4) {
         if (*selectedMap == MAP_A) {
             bpf_ipv4_socket_to_policies_map_A_update_elem(&cookie, &value, BPF_ANY);
@@ -286,12 +278,11 @@
 
     // Need to store bytes after updating map or program will not load.
     if (ipv4 && new_tos != (tos & 252)) {
-        int oldDscpVal = tos >> 2;
-        bpf_l3_csum_replace(skb, 1, oldDscpVal, new_tos, sizeof(uint8_t));
-        bpf_skb_store_bytes(skb, 1, &new_tos, sizeof(uint8_t), 0);
+        bpf_l3_csum_replace(skb, IP4_OFFSET(check, l2_header_size), htons(tos), htons(new_tos), 2);
+        bpf_skb_store_bytes(skb, IP4_OFFSET(tos, l2_header_size), &new_tos, sizeof(new_tos), 0);
     } else if (!ipv4 && (new_priority != priority || new_flow_lbl != flow_lbl)) {
-        bpf_skb_store_bytes(skb, 0, &new_priority, sizeof(uint8_t), 0);
-        bpf_skb_store_bytes(skb, 1, &new_flow_lbl, sizeof(uint8_t), 0);
+        bpf_skb_store_bytes(skb, l2_header_size, &new_priority, sizeof(new_priority), 0);
+        bpf_skb_store_bytes(skb, l2_header_size + 1, &new_flow_lbl, sizeof(new_flow_lbl), 0);
     }
     return;
 }
@@ -299,7 +290,6 @@
 DEFINE_BPF_PROG_KVER("schedcls/set_dscp_ether", AID_ROOT, AID_SYSTEM,
                      schedcls_set_dscp_ether, KVER(5, 15, 0))
 (struct __sk_buff* skb) {
-
     if (skb->pkt_type != PACKET_HOST) return TC_ACT_PIPE;
 
     if (skb->protocol == htons(ETH_P_IP)) {