Add type information to symbolfile and ndkstubgen.

Test: mypy symbolfile
Test: pytest
Bug: None
Change-Id: I6b1045d315e5a10e699d31de9fafc084d82768b2
diff --git a/cc/symbolfile/__init__.py b/cc/symbolfile/__init__.py
index faa3823..5678e7d 100644
--- a/cc/symbolfile/__init__.py
+++ b/cc/symbolfile/__init__.py
@@ -14,15 +14,31 @@
 # limitations under the License.
 #
 """Parser for Android's version script information."""
+from dataclasses import dataclass
 import logging
 import re
+from typing import (
+    Dict,
+    Iterable,
+    List,
+    Mapping,
+    NewType,
+    Optional,
+    TextIO,
+    Tuple,
+)
+
+
+ApiMap = Mapping[str, int]  # API level code name -> API level number.
+Arch = NewType('Arch', str)  # An architecture name, e.g. 'arm64'.
+Tag = NewType('Tag', str)  # One token of a line's '#' comment, e.g. 'llndk'.
 
 
 ALL_ARCHITECTURES = (
-    'arm',
-    'arm64',
-    'x86',
-    'x86_64',
+    Arch('arm'),
+    Arch('arm64'),
+    Arch('x86'),
+    Arch('x86_64'),
 )
 
 
@@ -30,18 +46,36 @@
 FUTURE_API_LEVEL = 10000
 
 
-def logger():
+def logger() -> logging.Logger:
     """Return the main logger for this module."""
     return logging.getLogger(__name__)
 
 
-def get_tags(line):
+@dataclass
+class Symbol:
+    """A symbol definition from a symbol file."""
+    # Dataclass __eq__ compares tags in order (the old __eq__ used sets).
+    name: str
+    tags: List[Tag]
+
+
+@dataclass
+class Version:
+    """A version block of a symbol file."""
+
+    name: str
+    base: Optional[str]  # None when the version block names no base.
+    tags: List[Tag]
+    symbols: List[Symbol]
+
+
+def get_tags(line: str) -> List[Tag]:
     """Returns a list of all tags on this line."""
     _, _, all_tags = line.strip().partition('#')
-    return [e for e in re.split(r'\s+', all_tags) if e.strip()]
+    return [Tag(e) for e in re.split(r'\s+', all_tags) if e.strip()]
 
 
-def is_api_level_tag(tag):
+def is_api_level_tag(tag: Tag) -> bool:
     """Returns true if this tag has an API level that may need decoding."""
     if tag.startswith('introduced='):
         return True
@@ -52,7 +86,7 @@
     return False
 
 
-def decode_api_level(api, api_map):
+def decode_api_level(api: str, api_map: ApiMap) -> int:
     """Decodes the API level argument into the API level number.
 
     For the average case, this just decodes the integer value from the string,
@@ -70,12 +104,13 @@
     return api_map[api]
 
 
-def decode_api_level_tags(tags, api_map):
-    """Decodes API level code names in a list of tags.
+def decode_api_level_tags(tags: Iterable[Tag], api_map: ApiMap) -> List[Tag]:
+    """Returns a copy of tags with API level code names decoded.
 
     Raises:
         ParseError: An unknown version name was found in a tag.
     """
-    for idx, tag in enumerate(tags):
+    decoded_tags = list(tags)  # Consumes tags exactly once; never mutates it.
+    for idx, tag in enumerate(decoded_tags):
         if not is_api_level_tag(tag):
             continue
@@ -83,13 +118,13 @@
 
         try:
             decoded = str(decode_api_level(value, api_map))
-            tags[idx] = '='.join([name, decoded])
-        except KeyError:
-            raise ParseError('Unknown version name in tag: {}'.format(tag))
-    return tags
+            decoded_tags[idx] = Tag('='.join([name, decoded]))
+        except KeyError as ex:
+            raise ParseError(f'Unknown version name in tag: {tag}') from ex
+    return decoded_tags
 
 
-def split_tag(tag):
+def split_tag(tag: Tag) -> Tuple[str, str]:
     """Returns a key/value tuple of the tag.
 
     Raises:
@@ -103,7 +138,7 @@
     return key, value
 
 
-def get_tag_value(tag):
+def get_tag_value(tag: Tag) -> str:
     """Returns the value of a key/value tag.
 
     Raises:
@@ -114,12 +149,13 @@
     return split_tag(tag)[1]
 
 
-def version_is_private(version):
+def version_is_private(version: str) -> bool:
     """Returns True if the version name should be treated as private."""
     return version.endswith('_PRIVATE') or version.endswith('_PLATFORM')
 
 
-def should_omit_version(version, arch, api, llndk, apex):
+def should_omit_version(version: Version, arch: Arch, api: int, llndk: bool,
+                        apex: bool) -> bool:
-    """Returns True if the version section should be ommitted.
+    """Returns True if the version section should be omitted.
 
     We want to omit any sections that do not have any symbols we'll have in the
@@ -145,7 +181,8 @@
     return False
 
 
-def should_omit_symbol(symbol, arch, api, llndk, apex):
+def should_omit_symbol(symbol: Symbol, arch: Arch, api: int, llndk: bool,
+                       apex: bool) -> bool:
     """Returns True if the symbol should be omitted."""
     no_llndk_no_apex = 'llndk' not in symbol.tags and 'apex' not in symbol.tags
     keep = no_llndk_no_apex or \
@@ -160,7 +197,7 @@
     return False
 
 
-def symbol_in_arch(tags, arch):
+def symbol_in_arch(tags: Iterable[Tag], arch: Arch) -> bool:
     """Returns true if the symbol is present for the given architecture."""
     has_arch_tags = False
     for tag in tags:
@@ -175,7 +212,7 @@
     return not has_arch_tags
 
 
-def symbol_in_api(tags, arch, api):
+def symbol_in_api(tags: Iterable[Tag], arch: Arch, api: int) -> bool:
     """Returns true if the symbol is present for the given API level."""
     introduced_tag = None
     arch_specific = False
@@ -197,7 +234,7 @@
     return api >= int(get_tag_value(introduced_tag))
 
 
-def symbol_versioned_in_api(tags, api):
+def symbol_versioned_in_api(tags: Iterable[Tag], api: int) -> bool:
     """Returns true if the symbol should be versioned for the given API.
 
     This models the `versioned=API` tag. This should be a very uncommonly
@@ -223,68 +260,40 @@
 
 class MultiplyDefinedSymbolError(RuntimeError):
     """A symbol name was multiply defined."""
-    def __init__(self, multiply_defined_symbols):
-        super(MultiplyDefinedSymbolError, self).__init__(
+    def __init__(self, multiply_defined_symbols: List[str]) -> None:
+        super().__init__(
             'Version script contains multiple definitions for: {}'.format(
                 ', '.join(multiply_defined_symbols)))
         self.multiply_defined_symbols = multiply_defined_symbols
 
 
-class Version:
-    """A version block of a symbol file."""
-    def __init__(self, name, base, tags, symbols):
-        self.name = name
-        self.base = base
-        self.tags = tags
-        self.symbols = symbols
-
-    def __eq__(self, other):
-        if self.name != other.name:
-            return False
-        if self.base != other.base:
-            return False
-        if self.tags != other.tags:
-            return False
-        if self.symbols != other.symbols:
-            return False
-        return True
-
-
-class Symbol:
-    """A symbol definition from a symbol file."""
-    def __init__(self, name, tags):
-        self.name = name
-        self.tags = tags
-
-    def __eq__(self, other):
-        return self.name == other.name and set(self.tags) == set(other.tags)
-
-
 class SymbolFileParser:
     """Parses NDK symbol files."""
-    def __init__(self, input_file, api_map, arch, api, llndk, apex):
+    def __init__(self, input_file: TextIO, api_map: ApiMap, arch: Arch,
+                 api: int, llndk: bool, apex: bool) -> None:
         self.input_file = input_file
         self.api_map = api_map
         self.arch = arch
         self.api = api
         self.llndk = llndk
         self.apex = apex
-        self.current_line = None
+        self.current_line: Optional[str] = None
 
-    def parse(self):
+    def parse(self) -> List[Version]:
         """Parses the symbol file and returns a list of Version objects."""
         versions = []
         while self.next_line() != '':
+            assert self.current_line is not None  # Narrows Optional for mypy.
             if '{' in self.current_line:
                 versions.append(self.parse_version())
             else:
                 raise ParseError(
-                    'Unexpected contents at top level: ' + self.current_line)
+                    f'Unexpected contents at top level: {self.current_line}')
 
         self.check_no_duplicate_symbols(versions)
         return versions
 
-    def check_no_duplicate_symbols(self, versions):
+    def check_no_duplicate_symbols(self, versions: Iterable[Version]) -> None:
         """Raises errors for multiply defined symbols.
 
         This situation is the normal case when symbol versioning is actually
@@ -312,12 +321,13 @@
             raise MultiplyDefinedSymbolError(
                 sorted(list(multiply_defined_symbols)))
 
-    def parse_version(self):
+    def parse_version(self) -> Version:
         """Parses a single version section and returns a Version object."""
+        assert self.current_line is not None  # Narrows Optional for mypy.
         name = self.current_line.split('{')[0].strip()
         tags = get_tags(self.current_line)
         tags = decode_api_level_tags(tags, self.api_map)
-        symbols = []
+        symbols: List[Symbol] = []
         global_scope = True
         cpp_symbols = False
         while self.next_line() != '':
@@ -333,9 +343,7 @@
                     cpp_symbols = False
                 else:
                     base = base.rstrip(';').rstrip()
-                    if base == '':
-                        base = None
-                    return Version(name, base, tags, symbols)
+                    return Version(name, base or None, tags, symbols)
             elif 'extern "C++" {' in self.current_line:
                 cpp_symbols = True
             elif not cpp_symbols and ':' in self.current_line:
@@ -354,8 +362,9 @@
                 pass
         raise ParseError('Unexpected EOF in version block.')
 
-    def parse_symbol(self):
+    def parse_symbol(self) -> Symbol:
         """Parses a single symbol line and returns a Symbol object."""
+        assert self.current_line is not None  # Narrows Optional for mypy.
         if ';' not in self.current_line:
             raise ParseError(
                 'Expected ; to terminate symbol: ' + self.current_line)
@@ -368,7 +377,7 @@
         tags = decode_api_level_tags(tags, self.api_map)
         return Symbol(name, tags)
 
-    def next_line(self):
+    def next_line(self) -> str:
         """Returns the next non-empty non-comment line.
 
         A return value of '' indicates EOF.