Refactor treble_sepolicy_tests.py
Introduce a new class TestPolicy to capture all the previous global
variables. This class contains the constructor and the loading methods
(Get*) that populate its internal state. The tests are modified to
accept a TestPolicy as their first argument.
This commit is a no-op: the behavior of the tests is unchanged.
`git show --ignore-space-change` can be used to skip over the
re-indentation due to the new class.
Bug: 269182257
Test: m selinux_policy (runs treble_sepolicy_tests against all
compatible versions)
Test: Set DEBUG=True, compare the generated scontexts before and after
this change. Identical.
Change-Id: Ia8da115dc1c0109b835e03b95da029b35712d251
diff --git a/tests/treble_sepolicy_tests.py b/tests/treble_sepolicy_tests.py
index b49f138..c966423 100644
--- a/tests/treble_sepolicy_tests.py
+++ b/tests/treble_sepolicy_tests.py
@@ -51,172 +51,166 @@
self.entrypointpaths = []
self.error = ""
-def PrintScontexts():
- for d in sorted(alldomains.keys()):
- sctx = alldomains[d]
- print(d)
- print("\tcoredomain="+str(sctx.coredomain))
- print("\tappdomain="+str(sctx.appdomain))
- print("\tfromSystem="+str(sctx.fromSystem))
- print("\tfromVendor="+str(sctx.fromVendor))
- print("\tattributes="+str(sctx.attributes))
- print("\tentrypoints="+str(sctx.entrypoints))
- print("\tentrypointpaths=")
- if sctx.entrypointpaths is not None:
- for path in sctx.entrypointpaths:
- print("\t\t"+str(path))
-alldomains = {}
-coredomains = set()
-appdomains = set()
-vendordomains = set()
-pol = None
+class TestPolicy:
+ """A policy loaded in memory with its domains easily accessible."""
-# compat vars
-alltypes = set()
-oldalltypes = set()
-compatMapping = None
-pubtypes = set()
+ def __init__(self):
+ self.alldomains = {}
+ self.coredomains = set()
+ self.appdomains = set()
+ self.vendordomains = set()
+ self.pol = None
-# Distinguish between PRODUCT_FULL_TREBLE and PRODUCT_FULL_TREBLE_OVERRIDE
-FakeTreble = False
+ # compat vars
+ self.alltypes = set()
+ self.oldalltypes = set()
+ self.compatMapping = None
+ self.pubtypes = set()
-def GetAllDomains(pol):
- global alldomains
- for result in pol.QueryTypeAttribute("domain", True):
- alldomains[result] = scontext()
+ # Distinguish between PRODUCT_FULL_TREBLE and PRODUCT_FULL_TREBLE_OVERRIDE
+ self.FakeTreble = False
-def GetAppDomains():
- global appdomains
- global alldomains
- for d in alldomains:
- # The application of the "appdomain" attribute is trusted because core
- # selinux policy contains neverallow rules that enforce that only zygote
- # and runas spawned processes may transition to processes that have
- # the appdomain attribute.
- if "appdomain" in alldomains[d].attributes:
- alldomains[d].appdomain = True
- appdomains.add(d)
+ def GetAllDomains(self):
+ for result in self.pol.QueryTypeAttribute("domain", True):
+ self.alldomains[result] = scontext()
-def GetCoreDomains():
- global alldomains
- global coredomains
- for d in alldomains:
- domain = alldomains[d]
- # TestCoredomainViolations will verify if coredomain was incorrectly
- # applied.
- if "coredomain" in domain.attributes:
- domain.coredomain = True
- coredomains.add(d)
- # check whether domains are executed off of /system or /vendor
- if d in coredomainAllowlist:
- continue
- # TODO(b/153112003): add checks to prevent app domains from being
- # incorrectly labeled as coredomain. Apps don't have entrypoints as
- # they're always dynamically transitioned to by zygote.
- if d in appdomains:
- continue
- # TODO(b/153112747): need to handle cases where there is a dynamic
- # transition OR there happens to be no context in AOSP files.
- if not domain.entrypointpaths:
- continue
+ def GetAppDomains(self):
+ for d in self.alldomains:
+ # The application of the "appdomain" attribute is trusted because core
+ # selinux policy contains neverallow rules that enforce that only zygote
+ # and runas spawned processes may transition to processes that have
+ # the appdomain attribute.
+ if "appdomain" in self.alldomains[d].attributes:
+ self.alldomains[d].appdomain = True
+ self.appdomains.add(d)
- for path in domain.entrypointpaths:
- vendor = any(MatchPathPrefix(path, prefix) for prefix in
- ["/vendor", "/odm"])
- system = any(MatchPathPrefix(path, prefix) for prefix in
- ["/init", "/system_ext", "/product" ])
+ def GetCoreDomains(self):
+ for d in self.alldomains:
+ domain = self.alldomains[d]
+ # TestCoredomainViolations will verify if coredomain was incorrectly
+ # applied.
+ if "coredomain" in domain.attributes:
+ domain.coredomain = True
+ self.coredomains.add(d)
+ # check whether domains are executed off of /system or /vendor
+ if d in coredomainAllowlist:
+ continue
+ # TODO(b/153112003): add checks to prevent app domains from being
+ # incorrectly labeled as coredomain. Apps don't have entrypoints as
+ # they're always dynamically transitioned to by zygote.
+ if d in self.appdomains:
+ continue
+ # TODO(b/153112747): need to handle cases where there is a dynamic
+ # transition OR there happens to be no context in AOSP files.
+ if not domain.entrypointpaths:
+ continue
- # only mark entrypoint as system if it is not in legacy /system/vendor
- if MatchPathPrefix(path, "/system/vendor"):
- vendor = True
- elif MatchPathPrefix(path, "/system"):
- system = True
+ for path in domain.entrypointpaths:
+ vendor = any(MatchPathPrefix(path, prefix) for prefix in
+ ["/vendor", "/odm"])
+ system = any(MatchPathPrefix(path, prefix) for prefix in
+ ["/init", "/system_ext", "/product" ])
- if not vendor and not system:
- domain.error += "Unrecognized entrypoint for " + d + " at " + path + "\n"
+ # only mark entrypoint as system if it is not in legacy /system/vendor
+ if MatchPathPrefix(path, "/system/vendor"):
+ vendor = True
+ elif MatchPathPrefix(path, "/system"):
+ system = True
- domain.fromSystem = domain.fromSystem or system
- domain.fromVendor = domain.fromVendor or vendor
+ if not vendor and not system:
+ domain.error += "Unrecognized entrypoint for " + d + " at " + path + "\n"
-###
-# Add the entrypoint type and path(s) to each domain.
-#
-def GetDomainEntrypoints(pol):
- global alldomains
- for x in pol.QueryExpandedTERule(tclass=set(["file"]), perms=set(["entrypoint"])):
- if not x.sctx in alldomains:
- continue
- alldomains[x.sctx].entrypoints.append(str(x.tctx))
- # postinstall_file represents a special case specific to A/B OTAs.
- # Update_engine mounts a partition and relabels it postinstall_file.
- # There is no file_contexts entry associated with postinstall_file
- # so skip the lookup.
- if x.tctx == "postinstall_file":
- continue
- entrypointpath = pol.QueryFc(x.tctx)
- if not entrypointpath:
- continue
- alldomains[x.sctx].entrypointpaths.extend(entrypointpath)
-###
-# Get attributes associated with each domain
-#
-def GetAttributes(pol):
- global alldomains
- for domain in alldomains:
- for result in pol.QueryTypeAttribute(domain, False):
- alldomains[domain].attributes.add(result)
+ domain.fromSystem = domain.fromSystem or system
+ domain.fromVendor = domain.fromVendor or vendor
-def GetAllTypes(pol, oldpol):
- global alltypes
- global oldalltypes
- alltypes = pol.GetAllTypes(False)
- oldalltypes = oldpol.GetAllTypes(False)
+ ###
+ # Add the entrypoint type and path(s) to each domain.
+ #
+ def GetDomainEntrypoints(self):
+ for x in self.pol.QueryExpandedTERule(tclass=set(["file"]), perms=set(["entrypoint"])):
+ if not x.sctx in self.alldomains:
+ continue
+ self.alldomains[x.sctx].entrypoints.append(str(x.tctx))
+ # postinstall_file represents a special case specific to A/B OTAs.
+ # Update_engine mounts a partition and relabels it postinstall_file.
+ # There is no file_contexts entry associated with postinstall_file
+ # so skip the lookup.
+ if x.tctx == "postinstall_file":
+ continue
+ entrypointpath = self.pol.QueryFc(x.tctx)
+ if not entrypointpath:
+ continue
+ self.alldomains[x.sctx].entrypointpaths.extend(entrypointpath)
-def setup(pol):
- GetAllDomains(pol)
- GetAttributes(pol)
- GetDomainEntrypoints(pol)
- GetAppDomains()
- GetCoreDomains()
+ ###
+ # Get attributes associated with each domain
+ #
+ def GetAttributes(self):
+ for domain in self.alldomains:
+ for result in self.pol.QueryTypeAttribute(domain, False):
+ self.alldomains[domain].attributes.add(result)
-# setup for the policy compatibility tests
-def compatSetup(pol, oldpol, mapping, types):
- global compatMapping
- global pubtypes
+ def setup(self, pol):
+ self.pol = pol
+ self.GetAllDomains()
+ self.GetAttributes()
+ self.GetDomainEntrypoints()
+ self.GetAppDomains()
+ self.GetCoreDomains()
- GetAllTypes(pol, oldpol)
- compatMapping = mapping
- pubtypes = types
+ def GetAllTypes(self, basepol, oldpol):
+ self.alltypes = basepol.GetAllTypes(False)
+ self.oldalltypes = oldpol.GetAllTypes(False)
-def DomainsWithAttribute(attr):
- global alldomains
- domains = []
- for domain in alldomains:
- if attr in alldomains[domain].attributes:
- domains.append(domain)
- return domains
+ # setup for the policy compatibility tests
+ def compatSetup(self, basepol, oldpol, mapping, types):
+ self.GetAllTypes(basepol, oldpol)
+ self.compatMapping = mapping
+ self.pubtypes = types
+
+ def DomainsWithAttribute(self, attr):
+ domains = []
+ for domain in self.alldomains:
+ if attr in self.alldomains[domain].attributes:
+ domains.append(domain)
+ return domains
+
+ def PrintScontexts(self):
+ for d in sorted(self.alldomains.keys()):
+ sctx = self.alldomains[d]
+ print(d)
+ print("\tcoredomain="+str(sctx.coredomain))
+ print("\tappdomain="+str(sctx.appdomain))
+ print("\tfromSystem="+str(sctx.fromSystem))
+ print("\tfromVendor="+str(sctx.fromVendor))
+ print("\tattributes="+str(sctx.attributes))
+ print("\tentrypoints="+str(sctx.entrypoints))
+ print("\tentrypointpaths=")
+ if sctx.entrypointpaths is not None:
+ for path in sctx.entrypointpaths:
+ print("\t\t"+str(path))
+
#############################################################
# Tests
#############################################################
-def TestCoredomainViolations():
- global alldomains
+def TestCoredomainViolations(test_policy):
# verify that all domains launched from /system have the coredomain
# attribute
ret = ""
- for d in alldomains:
- domain = alldomains[d]
+ for d in test_policy.alldomains:
+ domain = test_policy.alldomains[d]
if domain.fromSystem and domain.fromVendor:
ret += "The following domain is system and vendor: " + d + "\n"
- for domain in alldomains.values():
+ for domain in test_policy.alldomains.values():
ret += domain.error
violators = []
- for d in alldomains:
- domain = alldomains[d]
+ for d in test_policy.alldomains:
+ domain = test_policy.alldomains[d]
if domain.fromSystem and "coredomain" not in domain.attributes:
violators.append(d);
if len(violators) > 0:
@@ -228,8 +222,8 @@
# verify that all domains launched form /vendor do not have the coredomain
# attribute
violators = []
- for d in alldomains:
- domain = alldomains[d]
+ for d in test_policy.alldomains:
+ domain = test_policy.alldomains[d]
if domain.fromVendor and "coredomain" in domain.attributes:
violators.append(d)
if len(violators) > 0:
@@ -243,17 +237,13 @@
###
# Make sure that any new public type introduced in the new policy that was not
# present in the old policy has been recorded in the mapping file.
-def TestNoUnmappedNewTypes():
- global alltypes
- global oldalltypes
- global compatMapping
- global pubtypes
- newt = alltypes - oldalltypes
+def TestNoUnmappedNewTypes(test_policy):
+ newt = test_policy.alltypes - test_policy.oldalltypes
ret = ""
violators = []
for n in newt:
- if n in pubtypes and compatMapping.rTypeattributesets.get(n) is None:
+ if n in test_policy.pubtypes and test_policy.compatMapping.rTypeattributesets.get(n) is None:
violators.append(n)
if len(violators) > 0:
@@ -270,16 +260,13 @@
###
# Make sure that any public type removed in the current policy has its
# declaration added to the mapping file for use in non-platform policy
-def TestNoUnmappedRmTypes():
- global alltypes
- global oldalltypes
- global compatMapping
- rmt = oldalltypes - alltypes
+def TestNoUnmappedRmTypes(test_policy):
+ rmt = test_policy.oldalltypes - test_policy.alltypes
ret = ""
violators = []
for o in rmt:
- if o in compatMapping.pubtypes and not o in compatMapping.types:
+ if o in test_policy.compatMapping.pubtypes and not o in test_policy.compatMapping.types:
violators.append(o)
if len(violators) > 0:
@@ -292,34 +279,32 @@
ret += "https://android-review.googlesource.com/c/platform/system/sepolicy/+/822743\n"
return ret
-def TestTrebleCompatMapping():
- ret = TestNoUnmappedNewTypes()
- ret += TestNoUnmappedRmTypes()
+def TestTrebleCompatMapping(test_policy):
+ ret = TestNoUnmappedNewTypes(test_policy)
+ ret += TestNoUnmappedRmTypes(test_policy)
return ret
-def TestViolatorAttribute(attribute):
- global FakeTreble
+def TestViolatorAttribute(test_policy, attribute):
ret = ""
- if FakeTreble:
+ if test_policy.FakeTreble:
return ret
- violators = DomainsWithAttribute(attribute)
+ violators = test_policy.DomainsWithAttribute(attribute)
if len(violators) > 0:
ret += "SELinux: The following domains violate the Treble ban "
ret += "against use of the " + attribute + " attribute: "
ret += " ".join(str(x) for x in sorted(violators)) + "\n"
return ret
-def TestViolatorAttributes():
+def TestViolatorAttributes(test_policy):
ret = ""
- ret += TestViolatorAttribute("socket_between_core_and_vendor_violators")
- ret += TestViolatorAttribute("vendor_executes_system_violators")
+ ret += TestViolatorAttribute(test_policy, "socket_between_core_and_vendor_violators")
+ ret += TestViolatorAttribute(test_policy, "vendor_executes_system_violators")
return ret
# TODO move this to sepolicy_tests
-def TestCoreDataTypeViolations():
- global pol
- return pol.AssertPathTypesDoNotHaveAttr(["/data/vendor/", "/data/vendor_ce/",
+def TestCoreDataTypeViolations(test_policy):
+ return test_policy.pol.AssertPathTypesDoNotHaveAttr(["/data/vendor/", "/data/vendor_ce/",
"/data/vendor_de/"], [], "core_data_file_type")
###
@@ -349,7 +334,7 @@
Args:
libpath: string, path to libsepolwrap.so
"""
- global pol, FakeTreble
+ test_policy = TestPolicy()
usage = "treble_sepolicy_tests "
usage += "-f nonplat_file_contexts -f plat_file_contexts "
@@ -402,27 +387,27 @@
oldpol = policy.Policy(options.oldpolicy, None, libpath)
mapping = mini_parser.MiniCilParser(options.mapping)
pubpol = mini_parser.MiniCilParser(options.base_pub_policy)
- compatSetup(basepol, oldpol, mapping, pubpol.types)
+ test_policy.compatSetup(basepol, oldpol, mapping, pubpol.types)
if options.faketreble:
- FakeTreble = True
+ test_policy.FakeTreble = True
pol = policy.Policy(options.policy, options.file_contexts, libpath)
- setup(pol)
+ test_policy.setup(pol)
if DEBUG:
- PrintScontexts()
+ test_policy.PrintScontexts()
results = ""
# If an individual test is not specified, run all tests.
if options.tests is None:
for t in Tests.values():
- results += t()
+ results += t(test_policy)
else:
for tn in options.tests:
t = Tests.get(tn)
if t:
- results += t()
+ results += t(test_policy)
else:
err = "Error: unknown test: " + tn + "\n"
err += "Available tests:\n"