Merge "dscpPolicy: prevent cache overflow" into main
diff --git a/bpf/progs/dscpPolicy.c b/bpf/progs/dscpPolicy.c
index baabb02..39f2961 100644
--- a/bpf/progs/dscpPolicy.c
+++ b/bpf/progs/dscpPolicy.c
@@ -23,7 +23,10 @@
 #define ECN_MASK 3
 #define UPDATE_TOS(dscp, tos) ((dscp) << 2) | ((tos) & ECN_MASK)
 
-DEFINE_BPF_MAP_GRW(socket_policy_cache_map, HASH, uint64_t, RuleEntry, CACHE_MAP_SIZE, AID_SYSTEM)
+// Userspace neither reads nor writes this cache; slot = socket cookie % CACHE_MAP_SIZE (may collide)
+#define CACHE_MAP_SIZE 32  // keep a power of two so the % reduces to a cheap bitmask
+DEFINE_BPF_MAP_GRO(socket_policy_cache_map, PERCPU_ARRAY, uint32_t, RuleEntry, CACHE_MAP_SIZE,
+                   AID_SYSTEM)
 
 DEFINE_BPF_MAP_GRW(ipv4_dscp_policies_map, ARRAY, uint32_t, DscpPolicy, MAX_POLICIES, AID_SYSTEM)
 DEFINE_BPF_MAP_GRW(ipv6_dscp_policies_map, ARRAY, uint32_t, DscpPolicy, MAX_POLICIES, AID_SYSTEM)
@@ -43,6 +46,8 @@
     uint64_t cookie = bpf_get_socket_cookie(skb);
     if (!cookie) return;
 
+    uint32_t cacheid = cookie % CACHE_MAP_SIZE;
+
     __be16 sport = 0;
     uint16_t dport = 0;
     uint8_t protocol = 0;  // TODO: Use are reserved value? Or int (-1) and cast to uint below?
@@ -105,7 +110,8 @@
             return;
     }
 
-    RuleEntry* existing_rule = bpf_socket_policy_cache_map_lookup_elem(&cookie);
+    // this array lookup cannot actually fail
+    RuleEntry* existing_rule = bpf_socket_policy_cache_map_lookup_elem(&cacheid);
 
     if (existing_rule &&
         v6_equal(src_ip, existing_rule->src_ip) &&
@@ -192,7 +198,7 @@
     };
 
     // Update cache with found policy.
-    bpf_socket_policy_cache_map_update_elem(&cookie, &value, BPF_ANY);
+    bpf_socket_policy_cache_map_update_elem(&cacheid, &value, BPF_ANY);
 
     if (new_dscp < 0) return;
 
diff --git a/bpf/progs/dscpPolicy.h b/bpf/progs/dscpPolicy.h
index dc431a7..6a6b711 100644
--- a/bpf/progs/dscpPolicy.h
+++ b/bpf/progs/dscpPolicy.h
@@ -14,7 +14,6 @@
  * limitations under the License.
  */
 
-#define CACHE_MAP_SIZE 1024
 #define MAX_POLICIES 16
 
 #define STRUCT_SIZE(name, size) _Static_assert(sizeof(name) == (size), "Incorrect struct size.")