Merge "Allow traversal over the trie structure"
diff --git a/scripts/hiddenapi/signature_trie.py b/scripts/hiddenapi/signature_trie.py
index 5871834..e813a97 100644
--- a/scripts/hiddenapi/signature_trie.py
+++ b/scripts/hiddenapi/signature_trie.py
@@ -22,6 +22,19 @@
 
 @dataclasses.dataclass()
 class Node:
+    """A node in the signature trie."""
+
+    # The type of the node.
+    #
+    # Leaf nodes are of type "member".
+    # Interior nodes can be either "package", or "class".
+    type: str
+
+    # The selector of the node.
+    #
+    # That is a string that can be used to select the node, e.g. in a pattern
+    # that is passed to InteriorNode.get_matching_rows().
+    selector: str
 
     def values(self, selector):
         """Get the values from a set of selected nodes.
@@ -48,6 +61,10 @@
         """
         raise NotImplementedError("Please Implement this method")
 
+    def child_nodes(self):
+        """Return an iterable over the direct children of this node."""
+        raise NotImplementedError("Please Implement this method")
+
 
 # pylint: disable=line-too-long
 @dataclasses.dataclass()
@@ -173,22 +190,68 @@
         element_type, _ = InteriorNode.split_element(element)
         return element_type
 
-    def add(self, signature, value):
+    @staticmethod
+    def elements_to_selector(elements):
+        """Compute a selector for a set of elements.
+
+        A selector uniquely identifies a specific Node in the trie. It is
+        essentially a prefix of a signature (without the leading L).
+
+        e.g. a trie containing "Ljava/lang/Object;->hashCode()I"
+        would contain nodes with the following selectors:
+        * "java"
+        * "java/lang"
+        * "java/lang/Object"
+        * "java/lang/Object;->hashCode()I"
+        """
+        selector = ""
+        preceding_type = ""
+        for element in elements:
+            element_type, element_value = InteriorNode.split_element(element)
+            separator = ""
+            if element_type == "package":
+                separator = "/"
+            elif element_type == "class":
+                if preceding_type == "class":
+                    separator = "$"
+                else:
+                    separator = "/"
+            elif element_type == "wildcard":
+                separator = "/"
+            elif element_type == "member":
+                separator = ";->"
+
+            if selector:
+                selector += separator
+
+            selector += element_value
+
+            preceding_type = element_type
+
+        return selector
+
+    def add(self, signature, value, only_if_matches=False):
         """Associate the value with the specific signature.
 
         :param signature: the member signature
         :param value: the value to associated with the signature
+        :param only_if_matches: True if the value is added only if the signature
+             matches at least one of the existing top level packages.
         :return: n/a
         """
         # Split the signature into elements.
         elements = self.signature_to_elements(signature)
         # Find the Node associated with the deepest class.
         node = self
-        for element in elements[:-1]:
+        for index, element in enumerate(elements[:-1]):
             if element in node.nodes:
                 node = node.nodes[element]
+            elif only_if_matches and index == 0:
+                return
             else:
-                next_node = InteriorNode()
+                selector = self.elements_to_selector(elements[0:index + 1])
+                next_node = InteriorNode(
+                    type=InteriorNode.element_type(element), selector=selector)
                 node.nodes[element] = next_node
                 node = next_node
         # Add a Leaf containing the value and associate it with the member
@@ -201,7 +264,12 @@
                 "specific member")
         if last_element in node.nodes:
             raise Exception(f"Duplicate signature: {signature}")
-        node.nodes[last_element] = Leaf(value)
+        leaf = Leaf(
+            type=last_element_type,
+            selector=signature,
+            value=value,
+        )
+        node.nodes[last_element] = leaf
 
     def get_matching_rows(self, pattern):
         """Get the values (plural) associated with the pattern.
@@ -212,10 +280,6 @@
         If the pattern is a class then this will return a list containing the
         values associated with all members of that class.
 
-        If the pattern is a package then this will return a list containing the
-        values associated with all the members of all the classes in that
-        package and sub-packages.
-
         If the pattern ends with "*" then the preceding part is treated as a
         package and this will return a list containing the values associated
         with all the members of all the classes in that package.
@@ -261,6 +325,9 @@
             if selector(key):
                 node.append_values(values, lambda x: True)
 
+    def child_nodes(self):
+        return self.nodes.values()  # a live view over the child Nodes
+
 
 @dataclasses.dataclass()
 class Leaf(Node):
@@ -275,6 +342,9 @@
     def append_values(self, values, selector):
         values.append([self.value])
 
+    def child_nodes(self):
+        return []  # a Leaf never has children
+
 
 def signature_trie():
-    return InteriorNode()
+    return InteriorNode(type="root", selector="")  # root selector: no elements
diff --git a/scripts/hiddenapi/signature_trie_test.py b/scripts/hiddenapi/signature_trie_test.py
index e9644ef..1295691 100755
--- a/scripts/hiddenapi/signature_trie_test.py
+++ b/scripts/hiddenapi/signature_trie_test.py
@@ -27,6 +27,10 @@
     def signature_to_elements(signature):
         return InteriorNode.signature_to_elements(signature)
 
+    @staticmethod
+    def elements_to_signature(elements):
+        return InteriorNode.elements_to_selector(elements)  # no leading "L"
+
     def test_nested_inner_classes(self):
         elements = [
             ("package", "java"),
@@ -38,6 +42,7 @@
         ]
         signature = "Ljava/lang/ProcessBuilder$Redirect$1;-><init>()V"
         self.assertEqual(elements, self.signature_to_elements(signature))
+        self.assertEqual(signature, "L" + self.elements_to_signature(elements))
 
     def test_basic_member(self):
         elements = [
@@ -48,6 +53,7 @@
         ]
         signature = "Ljava/lang/Object;->hashCode()I"
         self.assertEqual(elements, self.signature_to_elements(signature))
+        self.assertEqual(signature, "L" + self.elements_to_signature(elements))
 
     def test_double_dollar_class(self):
         elements = [
@@ -61,6 +67,7 @@
         signature = "Ljava/lang/CharSequence$$ExternalSyntheticLambda0;" \
                     "-><init>(Ljava/lang/CharSequence;)V"
         self.assertEqual(elements, self.signature_to_elements(signature))
+        self.assertEqual(signature, "L" + self.elements_to_signature(elements))
 
     def test_no_member(self):
         elements = [
@@ -72,6 +79,7 @@
         ]
         signature = "Ljava/lang/CharSequence$$ExternalSyntheticLambda0"
         self.assertEqual(elements, self.signature_to_elements(signature))
+        self.assertEqual(signature, "L" + self.elements_to_signature(elements))
 
     def test_wildcard(self):
         elements = [
@@ -81,6 +89,7 @@
         ]
         signature = "java/lang/*"
         self.assertEqual(elements, self.signature_to_elements(signature))
+        self.assertEqual(signature, self.elements_to_signature(elements))
 
     def test_recursive_wildcard(self):
         elements = [
@@ -90,6 +99,7 @@
         ]
         signature = "java/lang/**"
         self.assertEqual(elements, self.signature_to_elements(signature))
+        self.assertEqual(signature, self.elements_to_signature(elements))
 
     def test_no_packages_wildcard(self):
         elements = [
@@ -97,6 +107,7 @@
         ]
         signature = "*"
         self.assertEqual(elements, self.signature_to_elements(signature))
+        self.assertEqual(signature, self.elements_to_signature(elements))
 
     def test_no_packages_recursive_wildcard(self):
         elements = [
@@ -104,6 +115,7 @@
         ]
         signature = "**"
         self.assertEqual(elements, self.signature_to_elements(signature))
+        self.assertEqual(signature, self.elements_to_signature(elements))
 
     def test_invalid_no_class_or_wildcard(self):
         signature = "java/lang"
@@ -121,6 +133,7 @@
         ]
         signature = "Ljavax/crypto/extObjectInputStream"
         self.assertEqual(elements, self.signature_to_elements(signature))
+        self.assertEqual(signature, "L" + self.elements_to_signature(elements))
 
     def test_invalid_pattern_wildcard(self):
         pattern = "Ljava/lang/Class*"
@@ -200,6 +213,18 @@
             "Ljava/util/zip/ZipFile;-><clinit>()V",
         ])
 
+    def test_node_wildcard(self):
+        trie = self.read_trie()
+        node = next(iter(trie.child_nodes()))  # presumably the "java" package
+        self.check_node_patterns(node, "**", [
+            "Ljava/lang/Character$UnicodeScript;->of(I)Ljava/lang/Character$UnicodeScript;",
+            "Ljava/lang/Character;->serialVersionUID:J",
+            "Ljava/lang/Object;->hashCode()I",
+            "Ljava/lang/Object;->toString()Ljava/lang/String;",
+            "Ljava/lang/ProcessBuilder$Redirect$1;-><init>()V",
+            "Ljava/util/zip/ZipFile;-><clinit>()V",
+        ])
+
     # pylint: enable=line-too-long