Merge "Allow init to upgrade MTE to sync."
diff --git a/libc/kernel/tools/cpp.py b/libc/kernel/tools/cpp.py
index a040474..1496231 100755
--- a/libc/kernel/tools/cpp.py
+++ b/libc/kernel/tools/cpp.py
@@ -155,6 +155,11 @@
     pass
 
 
+class UnparseableStruct(Exception):
+    """An exception that will be raised for structs that cannot be parsed."""
+    pass
+
+
 # The __contains__ function in libclang SourceRange class contains a bug. It
 # gives wrong result when dealing with single line range.
 # Bug filed with upstream:
@@ -1197,22 +1202,38 @@
 
     def removeStructs(self, structs):
         """Remove structs."""
-        for b in self.blocks:
+        extra_includes = []
+        block_num = 0
+        num_blocks = len(self.blocks)
+        while block_num < num_blocks:
+            b = self.blocks[block_num]
+            block_num += 1
             # Have to look in each block for a top-level struct definition.
             if b.directive:
                 continue
             num_tokens = len(b.tokens)
-            # A struct definition has at least 5 tokens:
+            # A struct definition usually looks like:
             #   struct
             #   ident
             #   {
             #   }
             #   ;
-            if num_tokens < 5:
+            # However, the structure might be spread across multiple blocks
+            # if the structure looks like this:
+            #   struct ident
+            #   {
+            #   #ifdef VARIABLE
+            #     pid_t pid;
+            #   #endif
+            #   };
+            # So the total number of tokens in the block might be less than
+            # five but assume at least three.
+            if num_tokens < 3:
                 continue
+
             # This is a simple struct finder, it might fail if a top-level
             # structure has an #if type directives that confuses the algorithm
-            # for finding th end of the structure. Or if there is another
+            # for finding the end of the structure. Or if there is another
             # structure definition embedded in the structure.
             i = 0
             while i < num_tokens - 2:
@@ -1223,24 +1244,58 @@
                 if (b.tokens[i + 1].kind == TokenKind.IDENTIFIER and
                     b.tokens[i + 2].kind == TokenKind.PUNCTUATION and
                     b.tokens[i + 2].id == "{" and b.tokens[i + 1].id in structs):
+                    # Add an include for the structure to be removed of the form:
+                    #  #include <bits/STRUCT_NAME.h>
+                    struct_token = b.tokens[i + 1]
+                    if not structs[struct_token.id]:
+                        extra_includes.append("<bits/%s.h>" % struct_token.id)
+
                     # Search forward for the end of the structure.
-                    # Very simple search, look for } and ; tokens. If something
-                    # more complicated is needed we can add it later.
+                    # Very simple search, look for } and ; tokens.
+                    # If we hit the end of the block, we'll need to start
+                    # looking at the next block.
                     j = i + 3
-                    while j < num_tokens - 1:
-                        if (b.tokens[j].kind == TokenKind.PUNCTUATION and
-                            b.tokens[j].id == "}" and
-                            b.tokens[j + 1].kind == TokenKind.PUNCTUATION and
-                            b.tokens[j + 1].id == ";"):
-                            b.tokens = b.tokens[0:i] + b.tokens[j + 2:num_tokens]
+                    depth = 1
+                    struct_removed = False
+                    while not struct_removed:
+                        while j < num_tokens:
+                            if b.tokens[j].kind == TokenKind.PUNCTUATION:
+                                if b.tokens[j].id == '{':
+                                    depth += 1
+                                elif b.tokens[j].id == '}':
+                                    depth -= 1
+                                elif b.tokens[j].id == ';' and depth == 0:
+                                    b.tokens = b.tokens[0:i] + b.tokens[j + 1:num_tokens]
+                                    num_tokens = len(b.tokens)
+                                    struct_removed = True
+                                    break
+                            j += 1
+                        if not struct_removed:
+                            b.tokens = b.tokens[0:i]
+
+                            # Skip directive blocks.
+                            start_block = block_num
+                            while block_num < num_blocks:
+                                if not self.blocks[block_num].directive:
+                                    break
+                                block_num += 1
+                            if block_num >= num_blocks:
+                                # No non-directive block left to search, error out.
+                                raise UnparseableStruct("Cannot remove struct %s: %s" % (struct_token.id, struct_token.location))
+                            self.blocks = self.blocks[0:start_block] + self.blocks[block_num:num_blocks]
+                            num_blocks = len(self.blocks)
+                            b = self.blocks[start_block]
+                            block_num = start_block + 1
                             num_tokens = len(b.tokens)
-                            j = i
-                            break
-                        j += 1
-                    i = j
+                            i = 0
+                            j = 0
                     continue
                 i += 1
 
+        for extra_include in extra_includes:
+            replacement = CppStringTokenizer(extra_include)
+            self.blocks.insert(2, Block(replacement.tokens, directive='include'))
+
     def optimizeAll(self, macros):
         self.optimizeMacros(macros)
         self.optimizeIf01()
@@ -1404,35 +1459,12 @@
 
     def replaceTokens(self, replacements):
         """Replace tokens according to the given dict."""
-        extra_includes = []
         for b in self.blocks:
             made_change = False
             if b.isInclude() is None:
                 i = 0
                 while i < len(b.tokens):
                     tok = b.tokens[i]
-                    if (tok.kind == TokenKind.KEYWORD and tok.id == 'struct'
-                        and (i + 2) < len(b.tokens) and b.tokens[i + 2].id == '{'):
-                        struct_name = b.tokens[i + 1].id
-                        if struct_name in kernel_struct_replacements:
-                            extra_includes.append("<bits/%s.h>" % struct_name)
-                            end = i + 2
-                            depth = 1
-                            while end < len(b.tokens) and depth > 0:
-                                if b.tokens[end].id == '}':
-                                    depth -= 1
-                                elif b.tokens[end].id == '{':
-                                    depth += 1
-                                end += 1
-                            end += 1 # Swallow last '}'
-                            while end < len(b.tokens) and b.tokens[end].id != ';':
-                                end += 1
-                            end += 1 # Swallow ';'
-                            # Remove these tokens. We'll replace them later with a #include block.
-                            b.tokens[i:end] = []
-                            made_change = True
-                            # We've just modified b.tokens, so revisit the current offset.
-                            continue
                     if tok.kind == TokenKind.IDENTIFIER:
                         if tok.id in replacements:
                             tok.id = replacements[tok.id]
@@ -1447,10 +1479,6 @@
                 # Keep 'expr' in sync with 'tokens'.
                 b.expr = CppExpr(b.tokens)
 
-        for extra_include in extra_includes:
-            replacement = CppStringTokenizer(extra_include)
-            self.blocks.insert(2, Block(replacement.tokens, directive='include'))
-
 
 
 def strip_space(s):
@@ -2020,7 +2048,7 @@
   struct timeval val2;
 };
 """
-        self.assertEqual(self.parse(text, set(["remove"])), expected)
+        self.assertEqual(self.parse(text, {"remove": True}), expected)
 
     def test_remove_struct_from_end(self):
         text = """\
@@ -2039,7 +2067,7 @@
   struct timeval val2;
 };
 """
-        self.assertEqual(self.parse(text, set(["remove"])), expected)
+        self.assertEqual(self.parse(text, {"remove": True}), expected)
 
     def test_remove_minimal_struct(self):
         text = """\
@@ -2047,7 +2075,7 @@
 };
 """
         expected = "";
-        self.assertEqual(self.parse(text, set(["remove"])), expected)
+        self.assertEqual(self.parse(text, {"remove": True}), expected)
 
     def test_remove_struct_with_struct_fields(self):
         text = """\
@@ -2067,7 +2095,7 @@
   struct remove val2;
 };
 """
-        self.assertEqual(self.parse(text, set(["remove"])), expected)
+        self.assertEqual(self.parse(text, {"remove": True}), expected)
 
     def test_remove_consecutive_structs(self):
         text = """\
@@ -2099,7 +2127,7 @@
   struct timeval val2;
 };
 """
-        self.assertEqual(self.parse(text, set(["remove1", "remove2"])), expected)
+        self.assertEqual(self.parse(text, {"remove1": True, "remove2": True}), expected)
 
     def test_remove_multiple_structs(self):
         text = """\
@@ -2132,7 +2160,101 @@
   int val;
 };
 """
-        self.assertEqual(self.parse(text, set(["remove1", "remove2"])), expected)
+        self.assertEqual(self.parse(text, {"remove1": True, "remove2": True}), expected)
+
+    def test_remove_struct_with_inline_structs(self):
+        text = """\
+struct remove {
+  int val1;
+  int val2;
+  struct {
+    int val1;
+    struct {
+      int val1;
+    } level2;
+  } level1;
+};
+struct something {
+  struct timeval val1;
+  struct timeval val2;
+};
+"""
+        expected = """\
+struct something {
+  struct timeval val1;
+  struct timeval val2;
+};
+"""
+        self.assertEqual(self.parse(text, {"remove": True}), expected)
+
+    def test_remove_struct_across_blocks(self):
+        text = """\
+struct remove {
+  int val1;
+  int val2;
+#ifdef PARAMETER1
+  PARAMETER1
+#endif
+#ifdef PARAMETER2
+  PARAMETER2
+#endif
+};
+struct something {
+  struct timeval val1;
+  struct timeval val2;
+};
+"""
+        expected = """\
+struct something {
+  struct timeval val1;
+  struct timeval val2;
+};
+"""
+        self.assertEqual(self.parse(text, {"remove": True}), expected)
+
+    def test_remove_struct_across_blocks_multiple_structs(self):
+        text = """\
+struct remove1 {
+  int val1;
+  int val2;
+#ifdef PARAMETER1
+  PARAMETER1
+#endif
+#ifdef PARAMETER2
+  PARAMETER2
+#endif
+};
+struct remove2 {
+};
+struct something {
+  struct timeval val1;
+  struct timeval val2;
+};
+"""
+        expected = """\
+struct something {
+  struct timeval val1;
+  struct timeval val2;
+};
+"""
+        self.assertEqual(self.parse(text, {"remove1": True, "remove2": True}), expected)
+
+    def test_remove_multiple_struct_and_add_includes(self):
+        text = """\
+struct remove1 {
+  int val1;
+  int val2;
+};
+struct remove2 {
+  struct timeval val1;
+  struct timeval val2;
+};
+"""
+        expected = """\
+#include <bits/remove1.h>
+#include <bits/remove2.h>
+"""
+        self.assertEqual(self.parse(text, {"remove1": False, "remove2": False}), expected)
 
 
 class FullPathTest(unittest.TestCase):
diff --git a/libc/kernel/tools/defaults.py b/libc/kernel/tools/defaults.py
index 2089285..99bbc3e 100644
--- a/libc/kernel/tools/defaults.py
+++ b/libc/kernel/tools/defaults.py
@@ -35,16 +35,23 @@
     "__kernel_old_timeval": "1",
     }
 
-# this is the set of known kernel data structures we want to remove from
-# the final headers
-kernel_structs_to_remove = set(
-        [
-          # Remove the structures since they are still the same as
-          # timeval, itimerval.
-          "__kernel_old_timeval",
-          "__kernel_old_itimerval",
-        ]
-    )
+# This is the map of known kernel data structures we want to remove from
+# the final headers. If the map value is True, the structure is simply
+# removed. If the map value is False, then in addition to removing the
+# structure, an #include <bits/STRUCT.h> is added to the file.
+kernel_structs_to_remove = {
+    # Remove the structures since they are still the same as
+    # timeval, itimerval.
+    "__kernel_old_timeval": True,
+    "__kernel_old_itimerval": True,
+    # Replace all of the below structures with #include <bits/STRUCT.h>
+    "epoll_event": False,
+    "flock": False,
+    "flock64": False,
+    "in_addr": False,
+    "ip_mreq_source": False,
+    "ip_msfilter": False,
+    }
 
 # define to true if you want to remove all defined(CONFIG_FOO) tests
 # from the clean headers. testing shows that this is not strictly necessary
@@ -100,20 +107,6 @@
     }
 
 
-# This is the set of struct definitions that we want to replace with
-# a #include of <bits/struct.h> instead.
-kernel_struct_replacements = set(
-        [
-          "epoll_event",
-          "flock",
-          "flock64",
-          "in_addr",
-          "ip_mreq_source",
-          "ip_msfilter",
-        ]
-    )
-
-
 # This is the set of known static inline functions that we want to keep
 # in the final kernel headers.
 kernel_known_generic_statics = set(
diff --git a/libc/kernel/tools/update_all.py b/libc/kernel/tools/update_all.py
index f2e78da..abc72a4 100755
--- a/libc/kernel/tools/update_all.py
+++ b/libc/kernel/tools/update_all.py
@@ -27,8 +27,14 @@
 def ProcessFiles(updater, original_dir, modified_dir, src_rel_dir, update_rel_dir):
     # Delete the old headers before updating to the new headers.
     update_dir = os.path.join(get_kernel_dir(), update_rel_dir)
-    shutil.rmtree(update_dir)
-    os.mkdir(update_dir, 0o755)
+    for root, dirs, files in os.walk(update_dir, topdown=True):
+        for entry in files:
+            # BUILD is a special file that needs to be preserved.
+            if entry == "BUILD":
+                continue
+            os.remove(os.path.join(root, entry))
+        for entry in dirs:
+            shutil.rmtree(os.path.join(root, entry))
 
     src_dir = os.path.normpath(os.path.join(original_dir, src_rel_dir))
     src_dir_len = len(src_dir) + 1