Refactor treble_sepolicy_tests.py

Introduce a new class, TestPolicy, to hold the state that was previously
kept in global variables. The class provides a constructor and loading
methods (Get*) that populate its internal state. Each test now takes a
TestPolicy as its first argument.

This commit introduces no functional change: the tests check exactly the same conditions as before.

Use `git show --ignore-space-change` to hide the whitespace-only
re-indentation caused by moving the code into the new class.

Bug: 269182257
Test: m selinux_policy (runs treble_sepolicy_tests against all
			compatible versions)
Test: Set DEBUG=True and compare the generated scontexts before and after this change; they are identical.
Change-Id: Ia8da115dc1c0109b835e03b95da029b35712d251
diff --git a/tests/treble_sepolicy_tests.py b/tests/treble_sepolicy_tests.py
index b49f138..c966423 100644
--- a/tests/treble_sepolicy_tests.py
+++ b/tests/treble_sepolicy_tests.py
@@ -51,172 +51,166 @@
         self.entrypointpaths = []
         self.error = ""
 
-def PrintScontexts():
-    for d in sorted(alldomains.keys()):
-        sctx = alldomains[d]
-        print(d)
-        print("\tcoredomain="+str(sctx.coredomain))
-        print("\tappdomain="+str(sctx.appdomain))
-        print("\tfromSystem="+str(sctx.fromSystem))
-        print("\tfromVendor="+str(sctx.fromVendor))
-        print("\tattributes="+str(sctx.attributes))
-        print("\tentrypoints="+str(sctx.entrypoints))
-        print("\tentrypointpaths=")
-        if sctx.entrypointpaths is not None:
-            for path in sctx.entrypointpaths:
-                print("\t\t"+str(path))
 
-alldomains = {}
-coredomains = set()
-appdomains = set()
-vendordomains = set()
-pol = None
+class TestPolicy:
+    """A policy loaded in memory with its domains easily accessible."""
 
-# compat vars
-alltypes = set()
-oldalltypes = set()
-compatMapping = None
-pubtypes = set()
+    def __init__(self):
+        self.alldomains = {}
+        self.coredomains = set()
+        self.appdomains = set()
+        self.vendordomains = set()
+        self.pol = None
 
-# Distinguish between PRODUCT_FULL_TREBLE and PRODUCT_FULL_TREBLE_OVERRIDE
-FakeTreble = False
+        # compat vars
+        self.alltypes = set()
+        self.oldalltypes = set()
+        self.compatMapping = None
+        self.pubtypes = set()
 
-def GetAllDomains(pol):
-    global alldomains
-    for result in pol.QueryTypeAttribute("domain", True):
-        alldomains[result] = scontext()
+        # Distinguish between PRODUCT_FULL_TREBLE and PRODUCT_FULL_TREBLE_OVERRIDE
+        self.FakeTreble = False
 
-def GetAppDomains():
-    global appdomains
-    global alldomains
-    for d in alldomains:
-        # The application of the "appdomain" attribute is trusted because core
-        # selinux policy contains neverallow rules that enforce that only zygote
-        # and runas spawned processes may transition to processes that have
-        # the appdomain attribute.
-        if "appdomain" in alldomains[d].attributes:
-            alldomains[d].appdomain = True
-            appdomains.add(d)
+    def GetAllDomains(self):
+        for result in self.pol.QueryTypeAttribute("domain", True):
+            self.alldomains[result] = scontext()
 
-def GetCoreDomains():
-    global alldomains
-    global coredomains
-    for d in alldomains:
-        domain = alldomains[d]
-        # TestCoredomainViolations will verify if coredomain was incorrectly
-        # applied.
-        if "coredomain" in domain.attributes:
-            domain.coredomain = True
-            coredomains.add(d)
-        # check whether domains are executed off of /system or /vendor
-        if d in coredomainAllowlist:
-            continue
-        # TODO(b/153112003): add checks to prevent app domains from being
-        # incorrectly labeled as coredomain. Apps don't have entrypoints as
-        # they're always dynamically transitioned to by zygote.
-        if d in appdomains:
-            continue
-        # TODO(b/153112747): need to handle cases where there is a dynamic
-        # transition OR there happens to be no context in AOSP files.
-        if not domain.entrypointpaths:
-            continue
+    def GetAppDomains(self):
+        for d in self.alldomains:
+            # The application of the "appdomain" attribute is trusted because core
+            # selinux policy contains neverallow rules that enforce that only zygote
+            # and runas spawned processes may transition to processes that have
+            # the appdomain attribute.
+            if "appdomain" in self.alldomains[d].attributes:
+                self.alldomains[d].appdomain = True
+                self.appdomains.add(d)
 
-        for path in domain.entrypointpaths:
-            vendor = any(MatchPathPrefix(path, prefix) for prefix in
-                         ["/vendor", "/odm"])
-            system = any(MatchPathPrefix(path, prefix) for prefix in
-                         ["/init", "/system_ext", "/product" ])
+    def GetCoreDomains(self):
+        for d in self.alldomains:
+            domain = self.alldomains[d]
+            # TestCoredomainViolations will verify if coredomain was incorrectly
+            # applied.
+            if "coredomain" in domain.attributes:
+                domain.coredomain = True
+                self.coredomains.add(d)
+            # check whether domains are executed off of /system or /vendor
+            if d in coredomainAllowlist:
+                continue
+            # TODO(b/153112003): add checks to prevent app domains from being
+            # incorrectly labeled as coredomain. Apps don't have entrypoints as
+            # they're always dynamically transitioned to by zygote.
+            if d in self.appdomains:
+                continue
+            # TODO(b/153112747): need to handle cases where there is a dynamic
+            # transition OR there happens to be no context in AOSP files.
+            if not domain.entrypointpaths:
+                continue
 
-            # only mark entrypoint as system if it is not in legacy /system/vendor
-            if MatchPathPrefix(path, "/system/vendor"):
-                vendor = True
-            elif MatchPathPrefix(path, "/system"):
-                system = True
+            for path in domain.entrypointpaths:
+                vendor = any(MatchPathPrefix(path, prefix) for prefix in
+                             ["/vendor", "/odm"])
+                system = any(MatchPathPrefix(path, prefix) for prefix in
+                             ["/init", "/system_ext", "/product" ])
 
-            if not vendor and not system:
-                domain.error += "Unrecognized entrypoint for " + d + " at " + path + "\n"
+                # only mark entrypoint as system if it is not in legacy /system/vendor
+                if MatchPathPrefix(path, "/system/vendor"):
+                    vendor = True
+                elif MatchPathPrefix(path, "/system"):
+                    system = True
 
-            domain.fromSystem = domain.fromSystem or system
-            domain.fromVendor = domain.fromVendor or vendor
+                if not vendor and not system:
+                    domain.error += "Unrecognized entrypoint for " + d + " at " + path + "\n"
 
-###
-# Add the entrypoint type and path(s) to each domain.
-#
-def GetDomainEntrypoints(pol):
-    global alldomains
-    for x in pol.QueryExpandedTERule(tclass=set(["file"]), perms=set(["entrypoint"])):
-        if not x.sctx in alldomains:
-            continue
-        alldomains[x.sctx].entrypoints.append(str(x.tctx))
-        # postinstall_file represents a special case specific to A/B OTAs.
-        # Update_engine mounts a partition and relabels it postinstall_file.
-        # There is no file_contexts entry associated with postinstall_file
-        # so skip the lookup.
-        if x.tctx == "postinstall_file":
-            continue
-        entrypointpath = pol.QueryFc(x.tctx)
-        if not entrypointpath:
-            continue
-        alldomains[x.sctx].entrypointpaths.extend(entrypointpath)
-###
-# Get attributes associated with each domain
-#
-def GetAttributes(pol):
-    global alldomains
-    for domain in alldomains:
-        for result in pol.QueryTypeAttribute(domain, False):
-            alldomains[domain].attributes.add(result)
+                domain.fromSystem = domain.fromSystem or system
+                domain.fromVendor = domain.fromVendor or vendor
 
-def GetAllTypes(pol, oldpol):
-    global alltypes
-    global oldalltypes
-    alltypes = pol.GetAllTypes(False)
-    oldalltypes = oldpol.GetAllTypes(False)
+    ###
+    # Add the entrypoint type and path(s) to each domain.
+    #
+    def GetDomainEntrypoints(self):
+        for x in self.pol.QueryExpandedTERule(tclass=set(["file"]), perms=set(["entrypoint"])):
+            if not x.sctx in self.alldomains:
+                continue
+            self.alldomains[x.sctx].entrypoints.append(str(x.tctx))
+            # postinstall_file represents a special case specific to A/B OTAs.
+            # Update_engine mounts a partition and relabels it postinstall_file.
+            # There is no file_contexts entry associated with postinstall_file
+            # so skip the lookup.
+            if x.tctx == "postinstall_file":
+                continue
+            entrypointpath = self.pol.QueryFc(x.tctx)
+            if not entrypointpath:
+                continue
+            self.alldomains[x.sctx].entrypointpaths.extend(entrypointpath)
 
-def setup(pol):
-    GetAllDomains(pol)
-    GetAttributes(pol)
-    GetDomainEntrypoints(pol)
-    GetAppDomains()
-    GetCoreDomains()
+    ###
+    # Get attributes associated with each domain
+    #
+    def GetAttributes(self):
+        for domain in self.alldomains:
+            for result in self.pol.QueryTypeAttribute(domain, False):
+                self.alldomains[domain].attributes.add(result)
 
-# setup for the policy compatibility tests
-def compatSetup(pol, oldpol, mapping, types):
-    global compatMapping
-    global pubtypes
+    def setup(self, pol):
+        self.pol = pol
+        self.GetAllDomains()
+        self.GetAttributes()
+        self.GetDomainEntrypoints()
+        self.GetAppDomains()
+        self.GetCoreDomains()
 
-    GetAllTypes(pol, oldpol)
-    compatMapping = mapping
-    pubtypes = types
+    def GetAllTypes(self, basepol, oldpol):
+        self.alltypes = basepol.GetAllTypes(False)
+        self.oldalltypes = oldpol.GetAllTypes(False)
 
-def DomainsWithAttribute(attr):
-    global alldomains
-    domains = []
-    for domain in alldomains:
-        if attr in alldomains[domain].attributes:
-            domains.append(domain)
-    return domains
+    # setup for the policy compatibility tests
+    def compatSetup(self, basepol, oldpol, mapping, types):
+        self.GetAllTypes(basepol, oldpol)
+        self.compatMapping = mapping
+        self.pubtypes = types
+
+    def DomainsWithAttribute(self, attr):
+        domains = []
+        for domain in self.alldomains:
+            if attr in self.alldomains[domain].attributes:
+                domains.append(domain)
+        return domains
+
+    def PrintScontexts(self):
+        for d in sorted(self.alldomains.keys()):
+            sctx = self.alldomains[d]
+            print(d)
+            print("\tcoredomain="+str(sctx.coredomain))
+            print("\tappdomain="+str(sctx.appdomain))
+            print("\tfromSystem="+str(sctx.fromSystem))
+            print("\tfromVendor="+str(sctx.fromVendor))
+            print("\tattributes="+str(sctx.attributes))
+            print("\tentrypoints="+str(sctx.entrypoints))
+            print("\tentrypointpaths=")
+            if sctx.entrypointpaths is not None:
+                for path in sctx.entrypointpaths:
+                    print("\t\t"+str(path))
+
 
 #############################################################
 # Tests
 #############################################################
-def TestCoredomainViolations():
-    global alldomains
+def TestCoredomainViolations(test_policy):
     # verify that all domains launched from /system have the coredomain
     # attribute
     ret = ""
 
-    for d in alldomains:
-        domain = alldomains[d]
+    for d in test_policy.alldomains:
+        domain = test_policy.alldomains[d]
         if domain.fromSystem and domain.fromVendor:
             ret += "The following domain is system and vendor: " + d + "\n"
 
-    for domain in alldomains.values():
+    for domain in test_policy.alldomains.values():
         ret += domain.error
 
     violators = []
-    for d in alldomains:
-        domain = alldomains[d]
+    for d in test_policy.alldomains:
+        domain = test_policy.alldomains[d]
         if domain.fromSystem and "coredomain" not in domain.attributes:
                 violators.append(d);
     if len(violators) > 0:
@@ -228,8 +222,8 @@
     # verify that all domains launched form /vendor do not have the coredomain
     # attribute
     violators = []
-    for d in alldomains:
-        domain = alldomains[d]
+    for d in test_policy.alldomains:
+        domain = test_policy.alldomains[d]
         if domain.fromVendor and "coredomain" in domain.attributes:
             violators.append(d)
     if len(violators) > 0:
@@ -243,17 +237,13 @@
 ###
 # Make sure that any new public type introduced in the new policy that was not
 # present in the old policy has been recorded in the mapping file.
-def TestNoUnmappedNewTypes():
-    global alltypes
-    global oldalltypes
-    global compatMapping
-    global pubtypes
-    newt = alltypes - oldalltypes
+def TestNoUnmappedNewTypes(test_policy):
+    newt = test_policy.alltypes - test_policy.oldalltypes
     ret = ""
     violators = []
 
     for n in newt:
-        if n in pubtypes and compatMapping.rTypeattributesets.get(n) is None:
+        if n in test_policy.pubtypes and test_policy.compatMapping.rTypeattributesets.get(n) is None:
             violators.append(n)
 
     if len(violators) > 0:
@@ -270,16 +260,13 @@
 ###
 # Make sure that any public type removed in the current policy has its
 # declaration added to the mapping file for use in non-platform policy
-def TestNoUnmappedRmTypes():
-    global alltypes
-    global oldalltypes
-    global compatMapping
-    rmt = oldalltypes - alltypes
+def TestNoUnmappedRmTypes(test_policy):
+    rmt = test_policy.oldalltypes - test_policy.alltypes
     ret = ""
     violators = []
 
     for o in rmt:
-        if o in compatMapping.pubtypes and not o in compatMapping.types:
+        if o in test_policy.compatMapping.pubtypes and not o in test_policy.compatMapping.types:
             violators.append(o)
 
     if len(violators) > 0:
@@ -292,34 +279,32 @@
         ret += "https://android-review.googlesource.com/c/platform/system/sepolicy/+/822743\n"
     return ret
 
-def TestTrebleCompatMapping():
-    ret = TestNoUnmappedNewTypes()
-    ret += TestNoUnmappedRmTypes()
+def TestTrebleCompatMapping(test_policy):
+    ret = TestNoUnmappedNewTypes(test_policy)
+    ret += TestNoUnmappedRmTypes(test_policy)
     return ret
 
-def TestViolatorAttribute(attribute):
-    global FakeTreble
+def TestViolatorAttribute(test_policy, attribute):
     ret = ""
-    if FakeTreble:
+    if test_policy.FakeTreble:
         return ret
 
-    violators = DomainsWithAttribute(attribute)
+    violators = test_policy.DomainsWithAttribute(attribute)
     if len(violators) > 0:
         ret += "SELinux: The following domains violate the Treble ban "
         ret += "against use of the " + attribute + " attribute: "
         ret += " ".join(str(x) for x in sorted(violators)) + "\n"
     return ret
 
-def TestViolatorAttributes():
+def TestViolatorAttributes(test_policy):
     ret = ""
-    ret += TestViolatorAttribute("socket_between_core_and_vendor_violators")
-    ret += TestViolatorAttribute("vendor_executes_system_violators")
+    ret += TestViolatorAttribute(test_policy, "socket_between_core_and_vendor_violators")
+    ret += TestViolatorAttribute(test_policy, "vendor_executes_system_violators")
     return ret
 
 # TODO move this to sepolicy_tests
-def TestCoreDataTypeViolations():
-    global pol
-    return pol.AssertPathTypesDoNotHaveAttr(["/data/vendor/", "/data/vendor_ce/",
+def TestCoreDataTypeViolations(test_policy):
+    return test_policy.pol.AssertPathTypesDoNotHaveAttr(["/data/vendor/", "/data/vendor_ce/",
         "/data/vendor_de/"], [], "core_data_file_type")
 
 ###
@@ -349,7 +334,7 @@
     Args:
         libpath: string, path to libsepolwrap.so
     """
-    global pol, FakeTreble
+    test_policy = TestPolicy()
 
     usage = "treble_sepolicy_tests "
     usage += "-f nonplat_file_contexts -f plat_file_contexts "
@@ -402,27 +387,27 @@
         oldpol = policy.Policy(options.oldpolicy, None, libpath)
         mapping = mini_parser.MiniCilParser(options.mapping)
         pubpol = mini_parser.MiniCilParser(options.base_pub_policy)
-        compatSetup(basepol, oldpol, mapping, pubpol.types)
+        test_policy.compatSetup(basepol, oldpol, mapping, pubpol.types)
 
     if options.faketreble:
-        FakeTreble = True
+        test_policy.FakeTreble = True
 
     pol = policy.Policy(options.policy, options.file_contexts, libpath)
-    setup(pol)
+    test_policy.setup(pol)
 
     if DEBUG:
-        PrintScontexts()
+        test_policy.PrintScontexts()
 
     results = ""
     # If an individual test is not specified, run all tests.
     if options.tests is None:
         for t in Tests.values():
-            results += t()
+            results += t(test_policy)
     else:
         for tn in options.tests:
             t = Tests.get(tn)
             if t:
-                results += t()
+                results += t(test_policy)
             else:
                 err = "Error: unknown test: " + tn + "\n"
                 err += "Available tests:\n"