Merge "Allow traversal over the trie structure"
diff --git a/scripts/hiddenapi/signature_trie.py b/scripts/hiddenapi/signature_trie.py
index 5871834..e813a97 100644
--- a/scripts/hiddenapi/signature_trie.py
+++ b/scripts/hiddenapi/signature_trie.py
@@ -22,6 +22,19 @@
@dataclasses.dataclass()
class Node:
+ """A node in the signature trie."""
+
+ # The type of the node.
+ #
+ # Leaf nodes are of type "member" and the root node is of type "root".
+ # Other interior nodes are either "package" or "class".
+ type: str
+
+ # The selector of the node.
+ #
+ # That is, a string that can be used to select the node, e.g. in a pattern
+ # that is passed to InteriorNode.get_matching_rows().
+ selector: str
def values(self, selector):
"""Get the values from a set of selected nodes.
@@ -48,6 +61,10 @@
"""
raise NotImplementedError("Please Implement this method")
+ def child_nodes(self):
+ """Get an iterable of this node's direct child nodes."""
+ raise NotImplementedError("Please Implement this method")
+
# pylint: disable=line-too-long
@dataclasses.dataclass()
@@ -173,22 +190,77 @@
element_type, _ = InteriorNode.split_element(element)
return element_type
- def add(self, signature, value):
+ @staticmethod
+ def elements_to_selector(elements):
+ """Compute a selector for a set of elements.
+
+ A selector uniquely identifies a specific Node in the trie. For an
+ interior node it is essentially a prefix of a signature (without the
+ leading L).
+
+ e.g. a trie containing the member signature
+ "Ljava/lang/Object;->toString()Ljava/lang/String;" would contain nodes
+ with the following selectors:
+ * "java"
+ * "java/lang"
+ * "java/lang/Object"
+ * "Ljava/lang/Object;->toString()Ljava/lang/String;"
+
+ Only interior node selectors are computed by this method; a leaf
+ (member) node uses the full signature passed to add() as its selector.
+
+ :param elements: the (type, value) elements, e.g. as returned by
+ signature_to_elements().
+ :return: the selector string.
+ """
+ signature = ""
+ preceding_type = ""
+ for element in elements:
+ element_type, element_value = InteriorNode.split_element(element)
+ separator = ""
+ if element_type == "package":
+ separator = "/"
+ elif element_type == "class":
+ if preceding_type == "class":
+ separator = "$"
+ else:
+ separator = "/"
+ elif element_type == "wildcard":
+ separator = "/"
+ elif element_type == "member":
+ separator = ";->"
+
+ if signature:
+ signature += separator
+
+ signature += element_value
+
+ preceding_type = element_type
+
+ return signature
+
+ def add(self, signature, value, only_if_matches=False):
"""Associate the value with the specific signature.
:param signature: the member signature
:param value: the value to associated with the signature
+ :param only_if_matches: True if the value is added only if the signature
+ matches at least one of the existing top level packages.
:return: n/a
"""
# Split the signature into elements.
elements = self.signature_to_elements(signature)
# Find the Node associated with the deepest class.
node = self
- for element in elements[:-1]:
+ for index, element in enumerate(elements[:-1]):
if element in node.nodes:
node = node.nodes[element]
+ elif only_if_matches and index == 0:
+ return  # top level package not already in the trie
else:
- next_node = InteriorNode()
+ selector = self.elements_to_selector(elements[0:index + 1])
+ next_node = InteriorNode(
+ type=InteriorNode.element_type(element), selector=selector)
node.nodes[element] = next_node
node = next_node
# Add a Leaf containing the value and associate it with the member
@@ -201,7 +264,12 @@
"specific member")
if last_element in node.nodes:
raise Exception(f"Duplicate signature: {signature}")
- node.nodes[last_element] = Leaf(value)
+ leaf = Leaf(
+ type=last_element_type,
+ selector=signature,  # the full signature passed to add()
+ value=value,
+ )
+ node.nodes[last_element] = leaf
def get_matching_rows(self, pattern):
"""Get the values (plural) associated with the pattern.
@@ -212,10 +280,6 @@
If the pattern is a class then this will return a list containing the
values associated with all members of that class.
- If the pattern is a package then this will return a list containing the
- values associated with all the members of all the classes in that
- package and sub-packages.
-
If the pattern ends with "*" then the preceding part is treated as a
package and this will return a list containing the values associated
with all the members of all the classes in that package.
@@ -261,6 +325,9 @@
if selector(key):
node.append_values(values, lambda x: True)
+ def child_nodes(self):
+ return self.nodes.values()  # a view of the child nodes
+
@dataclasses.dataclass()
class Leaf(Node):
@@ -275,6 +342,9 @@
def append_values(self, values, selector):
values.append([self.value])
+ def child_nodes(self):
+ return []  # a leaf has no child nodes
+
def signature_trie():
- return InteriorNode()
+ return InteriorNode(type="root", selector="")  # empty root selector
diff --git a/scripts/hiddenapi/signature_trie_test.py b/scripts/hiddenapi/signature_trie_test.py
index e9644ef..1295691 100755
--- a/scripts/hiddenapi/signature_trie_test.py
+++ b/scripts/hiddenapi/signature_trie_test.py
@@ -27,6 +27,10 @@
def signature_to_elements(signature):
return InteriorNode.signature_to_elements(signature)
+ @staticmethod
+ def elements_to_signature(elements):
+ return InteriorNode.elements_to_selector(elements)  # no leading "L"
+
def test_nested_inner_classes(self):
elements = [
("package", "java"),
@@ -38,6 +42,7 @@
]
signature = "Ljava/lang/ProcessBuilder$Redirect$1;-><init>()V"
self.assertEqual(elements, self.signature_to_elements(signature))
+ self.assertEqual(signature, "L" + self.elements_to_signature(elements))
def test_basic_member(self):
elements = [
@@ -48,6 +53,7 @@
]
signature = "Ljava/lang/Object;->hashCode()I"
self.assertEqual(elements, self.signature_to_elements(signature))
+ self.assertEqual(signature, "L" + self.elements_to_signature(elements))
def test_double_dollar_class(self):
elements = [
@@ -61,6 +67,7 @@
signature = "Ljava/lang/CharSequence$$ExternalSyntheticLambda0;" \
"-><init>(Ljava/lang/CharSequence;)V"
self.assertEqual(elements, self.signature_to_elements(signature))
+ self.assertEqual(signature, "L" + self.elements_to_signature(elements))
def test_no_member(self):
elements = [
@@ -72,6 +79,7 @@
]
signature = "Ljava/lang/CharSequence$$ExternalSyntheticLambda0"
self.assertEqual(elements, self.signature_to_elements(signature))
+ self.assertEqual(signature, "L" + self.elements_to_signature(elements))
def test_wildcard(self):
elements = [
@@ -81,6 +89,7 @@
]
signature = "java/lang/*"
self.assertEqual(elements, self.signature_to_elements(signature))
+ self.assertEqual(signature, self.elements_to_signature(elements))
def test_recursive_wildcard(self):
elements = [
@@ -90,6 +99,7 @@
]
signature = "java/lang/**"
self.assertEqual(elements, self.signature_to_elements(signature))
+ self.assertEqual(signature, self.elements_to_signature(elements))
def test_no_packages_wildcard(self):
elements = [
@@ -97,6 +107,7 @@
]
signature = "*"
self.assertEqual(elements, self.signature_to_elements(signature))
+ self.assertEqual(signature, self.elements_to_signature(elements))
def test_no_packages_recursive_wildcard(self):
elements = [
@@ -104,6 +115,7 @@
]
signature = "**"
self.assertEqual(elements, self.signature_to_elements(signature))
+ self.assertEqual(signature, self.elements_to_signature(elements))
def test_invalid_no_class_or_wildcard(self):
signature = "java/lang"
@@ -121,6 +133,7 @@
]
signature = "Ljavax/crypto/extObjectInputStream"
self.assertEqual(elements, self.signature_to_elements(signature))
+ self.assertEqual(signature, "L" + self.elements_to_signature(elements))
def test_invalid_pattern_wildcard(self):
pattern = "Ljava/lang/Class*"
@@ -200,6 +213,20 @@
"Ljava/util/zip/ZipFile;-><clinit>()V",
])
+ def test_node_wildcard(self):
+ trie = self.read_trie()
+ # Look up the "java" package by selector; don't rely on child order.
+ node = next(n for n in trie.child_nodes() if n.selector == "java")
+ self.assertEqual("package", node.type)
+ self.check_node_patterns(node, "**", [
+ "Ljava/lang/Character$UnicodeScript;->of(I)Ljava/lang/Character$UnicodeScript;",
+ "Ljava/lang/Character;->serialVersionUID:J",
+ "Ljava/lang/Object;->hashCode()I",
+ "Ljava/lang/Object;->toString()Ljava/lang/String;",
+ "Ljava/lang/ProcessBuilder$Redirect$1;-><init>()V",
+ "Ljava/util/zip/ZipFile;-><clinit>()V",
+ ])
+
# pylint: enable=line-too-long