Speed up seccomp with priority list.

Allow providing a list of prioritized syscalls (e.g., syscalls that we
know occur often), which are checked before other syscalls in seccomp.

When constructing the bpf seccomp filter, traverse prioritized syscalls
in a linear list before checking all other syscalls in a binary tree.

Bug: 156732794
Test: make, inspect generated *_system_policy.cpp files
Test: simpleperf on futex/ioctl-heavy app seems to show 5-10% less time
spent in seccomp call

Change-Id: I509343bcd32ada90c0591785ab5cb12d2a38c31e
(cherry picked from commit ce84677733c18bc442f7f1b2f1840117c904db70)
diff --git a/libc/tools/genseccomp.py b/libc/tools/genseccomp.py
index cc0ff99..ba7e2ca 100755
--- a/libc/tools/genseccomp.py
+++ b/libc/tools/genseccomp.py
@@ -12,6 +12,7 @@
 
 
 BPF_JGE = "BPF_JUMP(BPF_JMP|BPF_JGE|BPF_K, {0}, {1}, {2})"
+BPF_JEQ = "BPF_JUMP(BPF_JMP|BPF_JEQ|BPF_K, {0}, {1}, {2})"
 BPF_ALLOW = "BPF_STMT(BPF_RET|BPF_K, SECCOMP_RET_ALLOW)"
 
 
@@ -37,6 +38,24 @@
   return set([x["name"] for x in parser.syscalls if x.get(architecture)])
 
 
+def load_syscall_priorities_from_file(file_path):
+  format_re = re.compile(r'^\s*([A-Za-z_][A-Za-z0-9_]+)\s*$')
+  priorities = []
+  with open(file_path) as f:
+    for line in f:
+      m = format_re.match(line)
+      if not m:
+        continue
+      try:
+        name = m.group(1)
+        priorities.append(name)
+      except:
+        logging.debug('Failed to parse %s from %s', (line, file_path))
+        pass
+
+  return priorities
+
+
 def merge_names(base_names, whitelist_names, blacklist_names):
   if bool(blacklist_names - base_names):
     raise RuntimeError("Blacklist item not in bionic - aborting " + str(
@@ -45,6 +64,20 @@
   return (base_names - blacklist_names) | whitelist_names
 
 
+def extract_priority_syscalls(syscalls, priorities):
+  # Extract syscalls that are not in the priority list
+  other_syscalls = \
+    [syscall for syscall in syscalls if syscall[0] not in priorities]
+  # For prioritized syscalls, keep the order in which they appear in the
+  # priority list
+  syscall_dict = {syscall[0]: syscall[1] for syscall in syscalls}
+  priority_syscalls = []
+  for name in priorities:
+    if name in syscall_dict.keys():
+      priority_syscalls.append((name, syscall_dict[name]))
+  return priority_syscalls, other_syscalls
+
+
 def parse_syscall_NRs(names_path):
   # The input is now the preprocessed source file. This will contain a lot
   # of junk from the preprocessor, but our lines will be in the format:
@@ -123,8 +156,21 @@
     return jump + first + second
 
 
-def convert_ranges_to_bpf(ranges):
-  bpf = convert_to_intermediate_bpf(ranges)
+# Converts the prioritized syscalls to a bpf list that is prepended to the
+# tree generated by convert_to_intermediate_bpf(). If we hit one of these
+# syscalls, shortcut to the allow statement at the bottom of the tree
+# immediately.
+def convert_priority_to_intermediate_bpf(priority_syscalls):
+  result = []
+  for i, syscall in enumerate(priority_syscalls):
+    result.append(BPF_JEQ.format(syscall[1], "{allow}", 0) +
+                  ", //" + syscall[0])
+  return result
+
+
+def convert_ranges_to_bpf(ranges, priority_syscalls):
+  bpf = convert_priority_to_intermediate_bpf(priority_syscalls) + \
+    convert_to_intermediate_bpf(ranges)
 
   # Now we know the size of the tree, we can substitute the {fail} and {allow}
   # placeholders
@@ -135,9 +181,8 @@
     # With bpfs jmp 0 means the next statement, so the distance to the end is
     # len(bpf) - i - 1, which is where we will put the kill statement, and
     # then the statement after that is the allow statement
-    if "{fail}" in statement and "{allow}" in statement:
-      bpf[i] = statement.format(fail=str(len(bpf) - i),
-                                allow=str(len(bpf) - i - 1))
+    bpf[i] = statement.format(fail=str(len(bpf) - i),
+                              allow=str(len(bpf) - i - 1))
 
   # Add the allow calls at the end. If the syscall is not matched, we will
   # continue. This allows the user to choose to match further syscalls, and
@@ -174,13 +219,15 @@
   return header + "\n".join(bpf) + footer
 
 
-def construct_bpf(syscalls, architecture, name_modifier):
-  ranges = convert_NRs_to_ranges(syscalls)
-  bpf = convert_ranges_to_bpf(ranges)
+def construct_bpf(syscalls, architecture, name_modifier, priorities):
+  priority_syscalls, other_syscalls = \
+    extract_priority_syscalls(syscalls, priorities)
+  ranges = convert_NRs_to_ranges(other_syscalls)
+  bpf = convert_ranges_to_bpf(ranges, priority_syscalls)
   return convert_bpf_to_output(bpf, architecture, name_modifier)
 
 
-def gen_policy(name_modifier, out_dir, base_syscall_file, syscall_files, syscall_NRs):
+def gen_policy(name_modifier, out_dir, base_syscall_file, syscall_files, syscall_NRs, priority_file):
   for arch in SupportedArchitectures:
     base_names = load_syscall_names_from_file(base_syscall_file, arch)
     whitelist_names = set()
@@ -190,6 +237,9 @@
         blacklist_names |= load_syscall_names_from_file(f, arch)
       else:
         whitelist_names |= load_syscall_names_from_file(f, arch)
+    priorities = []
+    if priority_file:
+      priorities = load_syscall_priorities_from_file(priority_file)
 
     allowed_syscalls = []
     for name in merge_names(base_names, whitelist_names, blacklist_names):
@@ -198,7 +248,7 @@
       except:
         logging.exception("Failed to find %s in %s", name, arch)
         raise
-    output = construct_bpf(allowed_syscalls, arch, name_modifier)
+    output = construct_bpf(allowed_syscalls, arch, name_modifier, priorities)
 
     # And output policy
     existing = ""
@@ -226,6 +276,7 @@
                             "following files: \n"
                             "* /blacklist.*\.txt$/ syscall blacklist.\n"
                             "* /whitelist.*\.txt$/ syscall whitelist.\n"
+                            "* /priority.txt$/ priorities for bpf rules.\n"
                             "* otherwise, syscall name-number mapping.\n"))
   args = parser.parse_args()
 
@@ -235,17 +286,21 @@
     logging.basicConfig(level=logging.INFO)
 
   syscall_files = []
+  priority_file = None
   syscall_NRs = {}
   for filename in args.files:
     if filename.lower().endswith('.txt'):
-      syscall_files.append(filename)
+      if filename.lower().endswith('priority.txt'):
+        priority_file = filename
+      else:
+        syscall_files.append(filename)
     else:
       m = re.search(r"libseccomp_gen_syscall_nrs_([^/]+)", filename)
       syscall_NRs[m.group(1)] = parse_syscall_NRs(filename)
 
   gen_policy(name_modifier=args.name_modifier, out_dir=args.out_dir,
              syscall_NRs=syscall_NRs, base_syscall_file=args.base_file,
-             syscall_files=args.files)
+             syscall_files=syscall_files, priority_file=priority_file)
 
 
 if __name__ == "__main__":