Merge "Extract a subset of the monolithic flags for comparison"
diff --git a/scripts/hiddenapi/verify_overlaps.py b/scripts/hiddenapi/verify_overlaps.py
index b5ca30c..fe4e359 100755
--- a/scripts/hiddenapi/verify_overlaps.py
+++ b/scripts/hiddenapi/verify_overlaps.py
@@ -19,10 +19,24 @@
 
 import argparse
 import csv
+from itertools import chain
 
 def dict_reader(input):
     return csv.DictReader(input, delimiter=',', quotechar='|', fieldnames=['signature'])
 
+def extract_subset_from_monolithic_flags_as_dict(monolithicFlagsDict, signatures):
+    """
+    Extract a subset of flags from the dict containing all the monolithic flags.
+
+    :param monolithicFlagsDict: the dict containing all the monolithic flags.
+    :param signatures: an iterable of signatures (e.g. dict keys) defining the subset.
+    :return: a new dict from signature to row. Rows are shared with (not copied
+        from) monolithicFlagsDict; a signature with no monolithic row maps to {}.
+    """
+    # Build the subset with a comprehension; iteration order of signatures is
+    # preserved and no local named `dict` (shadowing the builtin) is needed.
+    return {signature: monolithicFlagsDict.get(signature, {}) for signature in signatures}
+
 def read_signature_csv_from_stream_as_dict(stream):
     """
     Read the csv contents from the stream into a dict. The first column is assumed to be the
@@ -62,10 +76,14 @@
     modular dict, and monolithic dict respectively.
     """
     mismatchingSignatures = []
-    for signature, modularRow in modularFlagsDict.items():
-        modularFlags = modularRow.get(None, [])
+    # Create a sorted set of all the signatures from both the monolithic and
+    # modular dicts.
+    allSignatures = sorted(set(chain(monolithicFlagsDict.keys(), modularFlagsDict.keys())))
+    for signature in allSignatures:
         monolithicRow = monolithicFlagsDict.get(signature, {})
         monolithicFlags = monolithicRow.get(None, [])
+        modularRow = modularFlagsDict.get(signature, {})
+        modularFlags = modularRow.get(None, [])
         if monolithicFlags != modularFlags:
             mismatchingSignatures.append((signature, modularFlags, monolithicFlags))
     return mismatchingSignatures
@@ -80,10 +98,14 @@
     monolithicFlagsPath = args.monolithicFlags
     monolithicFlagsDict = read_signature_csv_from_file_as_dict(monolithicFlagsPath)
 
+    # For each subset specified on the command line, create dicts for the flags
+    # provided by the subset and the corresponding flags from the complete set of
+    # flags and compare them.
     failed = False
     for modularFlagsPath in args.modularFlags:
         modularFlagsDict = read_signature_csv_from_file_as_dict(modularFlagsPath)
-        mismatchingSignatures = compare_signature_flags(monolithicFlagsDict, modularFlagsDict)
+        monolithicFlagsSubsetDict = extract_subset_from_monolithic_flags_as_dict(monolithicFlagsDict, modularFlagsDict.keys())
+        mismatchingSignatures = compare_signature_flags(monolithicFlagsSubsetDict, modularFlagsDict)
         if mismatchingSignatures:
             failed = True
             print("ERROR: Hidden API flags are inconsistent:")
diff --git a/scripts/hiddenapi/verify_overlaps_test.py b/scripts/hiddenapi/verify_overlaps_test.py
index 1248890..fdb7fa2 100755
--- a/scripts/hiddenapi/verify_overlaps_test.py
+++ b/scripts/hiddenapi/verify_overlaps_test.py
@@ -26,10 +26,28 @@
         with io.StringIO(csv) as f:
             return read_signature_csv_from_stream_as_dict(f)
 
+    extractInput = '''
+Ljava/lang/Object;->hashCode()I,public-api,system-api,test-api
+Ljava/lang/Object;->toString()Ljava/lang/String;,blocked
+'''
+
+    def test_extract_subset(self):  # only signatures listed in modular are extracted
+        monolithic = self.read_signature_csv_from_string_as_dict(TestDetectOverlaps.extractInput)
+        modular = self.read_signature_csv_from_string_as_dict('''
+Ljava/lang/Object;->hashCode()I,public-api,system-api,test-api
+''')
+        subset = extract_subset_from_monolithic_flags_as_dict(monolithic, modular.keys())
+        expected = {  # toString() is absent from modular, so it is excluded
+            'Ljava/lang/Object;->hashCode()I': {
+                None: ['public-api', 'system-api', 'test-api'],  # columns after the signature
+                'signature': 'Ljava/lang/Object;->hashCode()I',
+            },
+        }
+        self.assertEqual(expected, subset)
+
     def test_match(self):
         monolithic = self.read_signature_csv_from_string_as_dict('''
 Ljava/lang/Object;->hashCode()I,public-api,system-api,test-api
-Ljava/lang/Object;->toString()Ljava/lang/String;,blocked
 ''')
         modular = self.read_signature_csv_from_string_as_dict('''
 Ljava/lang/Object;->hashCode()I,public-api,system-api,test-api
@@ -58,7 +76,6 @@
 
     def test_mismatch_monolithic_blocked(self):
         monolithic = self.read_signature_csv_from_string_as_dict('''
-Ljava/lang/Object;->hashCode()I,public-api,system-api,test-api
 Ljava/lang/Object;->toString()Ljava/lang/String;,blocked
 ''')
         modular = self.read_signature_csv_from_string_as_dict('''
@@ -76,7 +93,6 @@
 
     def test_mismatch_modular_blocked(self):
         monolithic = self.read_signature_csv_from_string_as_dict('''
-Ljava/lang/Object;->hashCode()I,public-api,system-api,test-api
 Ljava/lang/Object;->toString()Ljava/lang/String;,public-api,system-api,test-api
 ''')
         modular = self.read_signature_csv_from_string_as_dict('''
@@ -93,9 +109,7 @@
         self.assertEqual(expected, mismatches)
 
     def test_missing_from_monolithic(self):
-        monolithic = self.read_signature_csv_from_string_as_dict('''
-Ljava/lang/Object;->hashCode()I,public-api,system-api,test-api
-''')
+        monolithic = self.read_signature_csv_from_string_as_dict('')
         modular = self.read_signature_csv_from_string_as_dict('''
 Ljava/lang/Object;->toString()Ljava/lang/String;,public-api,system-api,test-api
 ''')
@@ -110,27 +124,33 @@
         self.assertEqual(expected, mismatches)
 
     def test_missing_from_modular(self):
-        # The modular dict defines the set of signatures to compare so an entry
-        # in the monolithic dict that does not have a corresponding entry in the
-        # modular dict is ignored.
         monolithic = self.read_signature_csv_from_string_as_dict('''
 Ljava/lang/Object;->hashCode()I,public-api,system-api,test-api
 ''')
         modular = {}
         mismatches = compare_signature_flags(monolithic, modular)
-        expected = []
+        expected = [  # a signature present only in monolithic is reported as a mismatch
+            (
+                'Ljava/lang/Object;->hashCode()I',
+                [],  # modular flags: no modular entry, so none
+                ['public-api', 'system-api', 'test-api'],  # monolithic flags
+            ),
+        ]
         self.assertEqual(expected, mismatches)
 
     def test_blocked_missing_from_modular(self):
-        # The modular dict defines the set of signatures to compare so an entry
-        # in the monolithic dict that does not have a corresponding entry in the
-        # modular dict is ignored.
         monolithic = self.read_signature_csv_from_string_as_dict('''
 Ljava/lang/Object;->hashCode()I,blocked
 ''')
         modular = {}
         mismatches = compare_signature_flags(monolithic, modular)
-        expected = []
+        expected = [  # even a blocked signature present only in monolithic is reported
+            (
+                'Ljava/lang/Object;->hashCode()I',
+                [],  # modular flags: no modular entry, so none
+                ['blocked'],  # monolithic flags
+            ),
+        ]
         self.assertEqual(expected, mismatches)
 
 if __name__ == '__main__':