Merge "Use trie to store monolithic hidden API flags"
diff --git a/scripts/hiddenapi/verify_overlaps.py b/scripts/hiddenapi/verify_overlaps.py
index 8579321..a4a423e 100755
--- a/scripts/hiddenapi/verify_overlaps.py
+++ b/scripts/hiddenapi/verify_overlaps.py
@@ -19,12 +19,244 @@
 
 import argparse
 import csv
+import sys
 from itertools import chain
 
+class InteriorNode:
+    """
+    An interior node in a trie.
+
+    Each interior node has a dict that maps from an element of a signature to
+    either another interior node or a leaf. Each interior node represents either
+    a package, class or nested class. Class members are represented by a Leaf.
+
+    Associating the set of flags [public-api] with the signature
+    "Ljava/lang/Object;->toString()Ljava/lang/String;" will cause the following
+    nodes to be created:
+    Node()
+    ^- package:java -> Node()
+       ^- package:lang -> Node()
+           ^- class:Object -> Node()
+              ^- member:toString()Ljava/lang/String; -> Leaf([public-api])
+
+    Associating the set of flags [blocked,core-platform-api] with the signature
+    "Ljava/lang/Character$UnicodeScript;->of(I)Ljava/lang/Character$UnicodeScript;"
+    will cause the following nodes to be created:
+    Node()
+    ^- package:java -> Node()
+       ^- package:lang -> Node()
+           ^- class:Character -> Node()
+              ^- class:UnicodeScript -> Node()
+                 ^- member:of(I)Ljava/lang/Character$UnicodeScript;
+                    -> Leaf([blocked,core-platform-api])
+
+    Attributes:
+        nodes: a dict from an element of the signature to the Node/Leaf
+        containing the next element/value.
+    """
+    def __init__(self):
+        self.nodes = {}
+
+    def signatureToElements(self, signature):
+        """
+        Split a signature or a prefix into a number of elements:
+        1. The packages (excluding the leading L preceding the first package).
+        2. The class names, from outermost to innermost.
+        3. The member signature.
+
+        e.g. Ljava/lang/Character$UnicodeScript;->of(I)Ljava/lang/Character$UnicodeScript;
+        will be broken down into these elements:
+        1. package:java
+        2. package:lang
+        3. class:Character
+        4. class:UnicodeScript
+        5. member:of(I)Ljava/lang/Character$UnicodeScript;
+        """
+        # Remove the leading L, if present (patterns may omit it).
+        #  - java/lang/Character$UnicodeScript;->of(I)Ljava/lang/Character$UnicodeScript;
+        text = signature.removeprefix("L")
+        # Split the signature between qualified class name and the class member
+        # signature.
+        #  0 - java/lang/Character$UnicodeScript
+        #  1 - of(I)Ljava/lang/Character$UnicodeScript;
+        parts = text.split(";->")
+        member = parts[1:]
+        # Split the qualified class name into packages, and class name.
+        #  0 - java
+        #  1 - lang
+        #  2 - Character$UnicodeScript
+        elements = parts[0].split("/")
+        packages = elements[0:-1]
+        className = elements[-1]
+        if className == "*" or className == "**":
+            # Cannot specify a wildcard and target a specific member
+            if len(member) != 0:
+                raise Exception("Invalid signature %s: contains wildcard %s and member signature %s"
+                                % (signature, className, member[0]))
+            wildcard = [className]
+            # Assemble the parts into a single list, adding prefixes to identify
+            # the different parts.
+            #  0 - package:java
+            #  1 - package:lang
+            #  2 - *
+            return list(chain(map(lambda x : "package:" + x, packages),
+                              wildcard))
+        else:
+            # Split the class name into outer / inner classes
+            #  0 - Character
+            #  1 - UnicodeScript
+            classes = className.split("$")
+            # Assemble the parts into a single list, adding prefixes to identify
+            # the different parts.
+            #  0 - package:java
+            #  1 - package:lang
+            #  2 - class:Character
+            #  3 - class:UnicodeScript
+            #  4 - member:of(I)Ljava/lang/Character$UnicodeScript;
+            return list(chain(map(lambda x : "package:" + x, packages),
+                              map(lambda x : "class:" + x, classes),
+                              map(lambda x : "member:" + x, member)))
+
+    def add(self, signature, value):
+        """
+        Associate the value with the specific signature.
+        :param signature: the member signature
+        :param value: the value to associate with the signature
+        :return: n/a
+        """
+        # Split the signature into elements.
+        elements = self.signatureToElements(signature)
+        # Find the Node for the deepest class, creating missing nodes en route.
+        node = self
+        for element in elements[:-1]:
+            if element in node.nodes:
+                node = node.nodes[element]
+            else:
+                next = InteriorNode()
+                node.nodes[element] = next
+                node = next
+        # Add a Leaf containing the value and associate it with the member
+        # signature within the class.
+        lastElement = elements[-1]
+        if not lastElement.startswith("member:"):
+            raise Exception("Invalid signature: %s, does not identify a specific member" % signature)
+        if lastElement in node.nodes:
+            raise Exception("Duplicate signature: %s" % signature)
+        node.nodes[lastElement] = Leaf(value)
+
+    def getMatchingRows(self, pattern):
+        """
+        Get the values (plural) associated with the pattern.
+
+        e.g. If the pattern is a full signature then this will return a list
+        containing the value associated with that signature.
+
+        If the pattern is a class then this will return a list containing the
+        values associated with all members of that class.
+
+        A package can only be selected with a trailing "*" or "**" wildcard
+        (see below); without one the final path element is always looked up
+        as a class, so a bare package such as "java/lang" is not treated as one.
+
+        If the pattern ends with "*" then the preceding part is treated as a
+        package and this will return a list containing the values associated
+        with all the members of all the classes in that package.
+
+        If the pattern ends with "**" then the preceding part is treated
+        as a package and this will return a list containing the values
+        associated with all the members of all the classes in that package and
+        all sub-packages.
+
+        :param pattern: the pattern which could be a complete signature or a
+        class, or package wildcard.
+        :return: an iterable containing all the values associated with the
+        pattern.
+        """
+        elements = self.signatureToElements(pattern)
+        node = self
+        # Include all values from this node and all its children.
+        selector = lambda x : True
+        lastElement = elements[-1]
+        if lastElement == "*" or lastElement == "**":
+            elements = elements[:-1]
+            if lastElement == "*":
+                # Do not include values from sub-packages.
+                selector = lambda x : not x.startswith("package:")
+        for element in elements:
+            if element in node.nodes:
+                node = node.nodes[element]
+            else:
+                return []
+        return chain.from_iterable(node.values(selector))
+
+    def values(self, selector):
+        """
+        :param selector: a function that can be applied to a key in the nodes
+        attribute to determine whether to return its values.
+        :return: A list of iterables of all the values associated with this
+        node and its children.
+        """
+        values = []
+        self.appendValues(values, selector)
+        return values
+
+    def appendValues(self, values, selector):
+        """
+        Append the values associated with this node and its children to the
+        list.
+
+        For each item (key, child) in nodes the child node's values are returned
+        if and only if the selector returns True when called on its key. A child
+        node's values are all the values associated with it and all its
+        descendant nodes.
+
+        :param selector: a function that can be applied to a key in the nodes
+        attribute to determine whether to return its values.
+        :param values: a list of iterables of values.
+        """
+        for key, node in self.nodes.items():
+            if selector(key):
+                node.appendValues(values, lambda x : True)
+
+class Leaf:
+    """
+    A leaf of the trie
+
+    Attributes:
+        value: the value associated with this leaf.
+    """
+    def __init__(self, value):
+        self.value = value
+
+    def values(self, selector):
+        """
+        :return: A list containing one single-element list holding this leaf's value.
+        """
+        return [[self.value]]
+
+    def appendValues(self, values, selector):
+        """
+        Appends a list of the value associated with this node to the list.
+        :param values: a list of iterables of values.
+        """
+        values.append([self.value])
+
 def dict_reader(input):
     return csv.DictReader(input, delimiter=',', quotechar='|', fieldnames=['signature'])
 
-def extract_subset_from_monolithic_flags_as_dict_from_file(monolithicFlagsDict, patternsFile):
+def read_flag_trie_from_file(file):
+    with open(file, 'r') as stream:
+        return read_flag_trie_from_stream(stream)
+
+def read_flag_trie_from_stream(stream):
+    trie = InteriorNode()
+    reader = dict_reader(stream)
+    for row in reader:
+        signature = row['signature']
+        trie.add(signature, row)
+    return trie
+
+def extract_subset_from_monolithic_flags_as_dict_from_file(monolithicTrie, patternsFile):
     """
     Extract a subset of flags from the dict containing all the monolithic flags.
 
@@ -34,21 +266,24 @@
     :return: the dict from signature to row.
     """
     with open(patternsFile, 'r') as stream:
-        return extract_subset_from_monolithic_flags_as_dict_from_stream(monolithicFlagsDict, stream)
+        return extract_subset_from_monolithic_flags_as_dict_from_stream(monolithicTrie, stream)
 
-def extract_subset_from_monolithic_flags_as_dict_from_stream(monolithicFlagsDict, stream):
+def extract_subset_from_monolithic_flags_as_dict_from_stream(monolithicTrie, stream):
     """
-    Extract a subset of flags from the dict containing all the monolithic flags.
+    Extract a subset of flags from the trie containing all the monolithic flags.
 
-    :param monolithicFlagsDict: the dict containing all the monolithic flags.
+    :param monolithicTrie: the trie containing all the monolithic flags.
     :param stream: a stream containing a list of signature patterns that define
     the subset.
     :return: the dict from signature to row.
     """
     dict = {}
-    for signature in stream:
-        signature = signature.rstrip()
-        dict[signature] = monolithicFlagsDict.get(signature, {})
+    for pattern in stream:
+        pattern = pattern.rstrip()
+        rows = monolithicTrie.getMatchingRows(pattern)
+        for row in rows:
+            signature = row['signature']
+            dict[signature] = row
     return dict
 
 def read_signature_csv_from_stream_as_dict(stream):
@@ -108,20 +343,20 @@
     args_parser.add_argument('modularFlags', nargs=argparse.REMAINDER, help='Flags produced by individual bootclasspath_fragment modules')
     args = args_parser.parse_args(argv[1:])
 
-    # Read in the monolithic flags into a dict indexed by signature
+    # Read all the monolithic flags into a trie
     monolithicFlagsPath = args.monolithicFlags
-    monolithicFlagsDict = read_signature_csv_from_file_as_dict(monolithicFlagsPath)
+    monolithicTrie = read_flag_trie_from_file(monolithicFlagsPath)
 
     # For each subset specified on the command line, create dicts for the flags
-    # provided by the subset and the corresponding flags from the complete set of
-    # flags and compare them.
+    # provided by the subset and the corresponding flags from the complete set
+    # of flags and compare them.
     failed = False
     for modularPair in args.modularFlags:
         parts = modularPair.split(":")
         modularFlagsPath = parts[0]
         modularPatternsPath = parts[1]
         modularFlagsDict = read_signature_csv_from_file_as_dict(modularFlagsPath)
-        monolithicFlagsSubsetDict = extract_subset_from_monolithic_flags_as_dict_from_file(monolithicFlagsDict, modularPatternsPath)
+        monolithicFlagsSubsetDict = extract_subset_from_monolithic_flags_as_dict_from_file(monolithicTrie, modularPatternsPath)
         mismatchingSignatures = compare_signature_flags(monolithicFlagsSubsetDict, modularFlagsDict)
         if mismatchingSignatures:
             failed = True
diff --git a/scripts/hiddenapi/verify_overlaps_test.py b/scripts/hiddenapi/verify_overlaps_test.py
index b6d5fa3..7477254 100755
--- a/scripts/hiddenapi/verify_overlaps_test.py
+++ b/scripts/hiddenapi/verify_overlaps_test.py
@@ -20,8 +20,52 @@
 
 from verify_overlaps import *
 
+class TestSignatureToElements(unittest.TestCase):
+
+    def signatureToElements(self, signature):
+        return InteriorNode().signatureToElements(signature)
+
+    def test_signatureToElements_1(self):
+        expected = [
+            'package:java',
+            'package:lang',
+            'class:ProcessBuilder',
+            'class:Redirect',
+            'class:1',
+            'member:<init>()V',
+        ]
+        self.assertEqual(expected, self.signatureToElements(
+            "Ljava/lang/ProcessBuilder$Redirect$1;-><init>()V"))
+
+    def test_signatureToElements_2(self):
+        expected = [
+            'package:java',
+            'package:lang',
+            'class:Object',
+            'member:hashCode()I',
+        ]
+        self.assertEqual(expected, self.signatureToElements(
+            "Ljava/lang/Object;->hashCode()I"))
+
+    def test_signatureToElements_3(self):
+        expected = [
+            'package:java',
+            'package:lang',
+            'class:CharSequence',
+            'class:',
+            'class:ExternalSyntheticLambda0',
+            'member:<init>(Ljava/lang/CharSequence;)V',
+        ]
+        self.assertEqual(expected, self.signatureToElements(
+            "Ljava/lang/CharSequence$$ExternalSyntheticLambda0;"
+            "-><init>(Ljava/lang/CharSequence;)V"))
+
 class TestDetectOverlaps(unittest.TestCase):
 
+    def read_flag_trie_from_string(self, csv):
+        with io.StringIO(csv) as f:
+            return read_flag_trie_from_stream(f)
+
     def read_signature_csv_from_string_as_dict(self, csv):
         with io.StringIO(csv) as f:
             return read_signature_csv_from_stream_as_dict(f)
@@ -33,13 +77,14 @@
     extractInput = '''
 Ljava/lang/Object;->hashCode()I,public-api,system-api,test-api
 Ljava/lang/Object;->toString()Ljava/lang/String;,blocked
+Ljava/util/zip/ZipFile;-><clinit>()V,blocked
+Ljava/lang/Character$UnicodeScript;->of(I)Ljava/lang/Character$UnicodeScript;,blocked
+Ljava/lang/Character;->serialVersionUID:J,sdk
+Ljava/lang/ProcessBuilder$Redirect$1;-><init>()V,blocked
 '''
 
-    def test_extract_subset(self):
-        monolithic = self.read_signature_csv_from_string_as_dict(TestDetectOverlaps.extractInput)
-        modular = self.read_signature_csv_from_string_as_dict('''
-Ljava/lang/Object;->hashCode()I,public-api,system-api,test-api
-''')
+    def test_extract_subset_signature(self):
+        monolithic = self.read_flag_trie_from_string(TestDetectOverlaps.extractInput)
 
         patterns = 'Ljava/lang/Object;->hashCode()I'
 
@@ -52,6 +97,144 @@
         }
         self.assertEqual(expected, subset)
 
+    def test_extract_subset_class(self):
+        monolithic = self.read_flag_trie_from_string(TestDetectOverlaps.extractInput)
+
+        patterns = 'java/lang/Object'
+
+        subset = self.extract_subset_from_monolithic_flags_as_dict_from_string(monolithic, patterns)
+        expected = {
+            'Ljava/lang/Object;->hashCode()I': {
+                None: ['public-api', 'system-api', 'test-api'],
+                'signature': 'Ljava/lang/Object;->hashCode()I',
+            },
+            'Ljava/lang/Object;->toString()Ljava/lang/String;': {
+                None: ['blocked'],
+                'signature': 'Ljava/lang/Object;->toString()Ljava/lang/String;',
+            },
+        }
+        self.assertEqual(expected, subset)
+
+    def test_extract_subset_outer_class(self):
+        monolithic = self.read_flag_trie_from_string(TestDetectOverlaps.extractInput)
+
+        patterns = 'java/lang/Character'
+
+        subset = self.extract_subset_from_monolithic_flags_as_dict_from_string(monolithic, patterns)
+        expected = {
+            'Ljava/lang/Character$UnicodeScript;->of(I)Ljava/lang/Character$UnicodeScript;': {
+                None: ['blocked'],
+                'signature': 'Ljava/lang/Character$UnicodeScript;->of(I)Ljava/lang/Character$UnicodeScript;',
+            },
+            'Ljava/lang/Character;->serialVersionUID:J': {
+                None: ['sdk'],
+                'signature': 'Ljava/lang/Character;->serialVersionUID:J',
+            },
+        }
+        self.assertEqual(expected, subset)
+
+    def test_extract_subset_nested_class(self):
+        monolithic = self.read_flag_trie_from_string(TestDetectOverlaps.extractInput)
+
+        patterns = 'java/lang/Character$UnicodeScript'
+
+        subset = self.extract_subset_from_monolithic_flags_as_dict_from_string(monolithic, patterns)
+        expected = {
+            'Ljava/lang/Character$UnicodeScript;->of(I)Ljava/lang/Character$UnicodeScript;': {
+                None: ['blocked'],
+                'signature': 'Ljava/lang/Character$UnicodeScript;->of(I)Ljava/lang/Character$UnicodeScript;',
+            },
+        }
+        self.assertEqual(expected, subset)
+
+    def test_extract_subset_package(self):
+        monolithic = self.read_flag_trie_from_string(TestDetectOverlaps.extractInput)
+
+        patterns = 'java/lang/*'
+
+        subset = self.extract_subset_from_monolithic_flags_as_dict_from_string(monolithic, patterns)
+        expected = {
+            'Ljava/lang/Character$UnicodeScript;->of(I)Ljava/lang/Character$UnicodeScript;': {
+                None: ['blocked'],
+                'signature': 'Ljava/lang/Character$UnicodeScript;->of(I)Ljava/lang/Character$UnicodeScript;',
+            },
+            'Ljava/lang/Character;->serialVersionUID:J': {
+                None: ['sdk'],
+                'signature': 'Ljava/lang/Character;->serialVersionUID:J',
+            },
+            'Ljava/lang/Object;->hashCode()I': {
+                None: ['public-api', 'system-api', 'test-api'],
+                'signature': 'Ljava/lang/Object;->hashCode()I',
+            },
+            'Ljava/lang/Object;->toString()Ljava/lang/String;': {
+                None: ['blocked'],
+                'signature': 'Ljava/lang/Object;->toString()Ljava/lang/String;',
+            },
+            'Ljava/lang/ProcessBuilder$Redirect$1;-><init>()V': {
+                None: ['blocked'],
+                'signature': 'Ljava/lang/ProcessBuilder$Redirect$1;-><init>()V',
+            },
+        }
+        self.assertEqual(expected, subset)
+
+    def test_extract_subset_recursive_package(self):
+        monolithic = self.read_flag_trie_from_string(TestDetectOverlaps.extractInput)
+
+        patterns = 'java/**'
+
+        subset = self.extract_subset_from_monolithic_flags_as_dict_from_string(monolithic, patterns)
+        expected = {
+            'Ljava/lang/Character$UnicodeScript;->of(I)Ljava/lang/Character$UnicodeScript;': {
+                None: ['blocked'],
+                'signature': 'Ljava/lang/Character$UnicodeScript;->of(I)Ljava/lang/Character$UnicodeScript;',
+            },
+            'Ljava/lang/Character;->serialVersionUID:J': {
+                None: ['sdk'],
+                'signature': 'Ljava/lang/Character;->serialVersionUID:J',
+            },
+            'Ljava/lang/Object;->hashCode()I': {
+                None: ['public-api', 'system-api', 'test-api'],
+                'signature': 'Ljava/lang/Object;->hashCode()I',
+            },
+            'Ljava/lang/Object;->toString()Ljava/lang/String;': {
+                None: ['blocked'],
+                'signature': 'Ljava/lang/Object;->toString()Ljava/lang/String;',
+            },
+            'Ljava/lang/ProcessBuilder$Redirect$1;-><init>()V': {
+                None: ['blocked'],
+                'signature': 'Ljava/lang/ProcessBuilder$Redirect$1;-><init>()V',
+            },
+            'Ljava/util/zip/ZipFile;-><clinit>()V': {
+                None: ['blocked'],
+                'signature': 'Ljava/util/zip/ZipFile;-><clinit>()V',
+            },
+        }
+        self.assertEqual(expected, subset)
+
+    def test_extract_subset_invalid_pattern_wildcard_and_member(self):
+        monolithic = self.read_flag_trie_from_string(TestDetectOverlaps.extractInput)
+
+        patterns = 'Ljava/lang/*;->hashCode()I'
+
+        with self.assertRaises(Exception) as context:
+            self.extract_subset_from_monolithic_flags_as_dict_from_string(monolithic, patterns)
+        self.assertTrue("contains wildcard * and member signature hashCode()I" in str(context.exception))
+
+    def test_read_trie_duplicate(self):
+        with self.assertRaises(Exception) as context:
+            self.read_flag_trie_from_string('''
+Ljava/lang/Object;->hashCode()I,public-api,system-api,test-api
+Ljava/lang/Object;->hashCode()I,blocked
+''')
+        self.assertTrue("Duplicate signature: Ljava/lang/Object;->hashCode()I" in str(context.exception))
+
+    def test_read_trie_missing_member(self):
+        with self.assertRaises(Exception) as context:
+            self.read_flag_trie_from_string('''
+Ljava/lang/Object,public-api,system-api,test-api
+''')
+        self.assertTrue("Invalid signature: Ljava/lang/Object, does not identify a specific member" in str(context.exception))
+
     def test_match(self):
         monolithic = self.read_signature_csv_from_string_as_dict('''
 Ljava/lang/Object;->hashCode()I,public-api,system-api,test-api