releasetools: Factor out the check for (compressed) APK file.

Test: Run sign_target_files_apks.py to sign a target_files.zip.
Test: `python -m unittest test_sign_target_files_apks`
Change-Id: Ie795d1bce7bae6af427832283e3d10bfecad80c5
diff --git a/tools/releasetools/sign_target_files_apks.py b/tools/releasetools/sign_target_files_apks.py
index fa62c8f..756bc8a 100755
--- a/tools/releasetools/sign_target_files_apks.py
+++ b/tools/releasetools/sign_target_files_apks.py
@@ -144,28 +144,69 @@
   return certmap
 
 
+def GetApkFileInfo(filename, compressed_extension):
+  """Returns the APK info based on the given filename.
+
+  Checks if the given filename (with path) looks like an APK file, by taking the
+  compressed extension into consideration.
+
+  Args:
+    filename: Path to the file.
+    compressed_extension: The extension string of compressed APKs (e.g. ".gz"),
+        or None if there are no compressed APKs.
+
+  Returns:
+    (is_apk, is_compressed): is_apk indicates whether the given filename is an
+    APK file. is_compressed indicates whether the APK file is compressed (only
+    meaningful when is_apk is True).
+
+  Raises:
+    AssertionError: On invalid compressed_extension input.
+  """
+  assert compressed_extension is None or compressed_extension.startswith('.'), \
+      "Invalid compressed_extension arg: '{}'".format(compressed_extension)
+
+  compressed_apk_extension = (
+      ".apk" + compressed_extension if compressed_extension else None)
+  is_apk = (filename.endswith(".apk") or
+            (compressed_apk_extension and
+             filename.endswith(compressed_apk_extension)))
+  if not is_apk:
+    return (False, False)
+
+  is_compressed = (compressed_apk_extension and
+                   filename.endswith(compressed_apk_extension))
+  return (True, is_compressed)
+
+
 def CheckAllApksSigned(input_tf_zip, apk_key_map, compressed_extension):
-  """Check that all the APKs we want to sign have keys specified, and
-  error out if they don't."""
+  """Checks that all the APKs have keys specified, otherwise errors out.
+
+  Args:
+    input_tf_zip: An open target_files zip file.
+    apk_key_map: A dict of known signing keys, keyed by APK name.
+    compressed_extension: The extension string of compressed APKs, such as
+        ".gz", or None if there are no compressed APKs.
+
+  Raises:
+    AssertionError: On finding unknown APKs.
+  """
   unknown_apks = []
-  compressed_apk_extension = None
-  if compressed_extension:
-    compressed_apk_extension = ".apk" + compressed_extension
   for info in input_tf_zip.infolist():
-    if (info.filename.endswith(".apk") or
-        (compressed_apk_extension and
-         info.filename.endswith(compressed_apk_extension))):
-      name = os.path.basename(info.filename)
-      if compressed_apk_extension and name.endswith(compressed_apk_extension):
-        name = name[:-len(compressed_extension)]
-      if name not in apk_key_map:
-        unknown_apks.append(name)
-  if unknown_apks:
-    print("ERROR: no key specified for:\n")
-    print("  " + "\n  ".join(unknown_apks))
-    print("\nUse '-e <apkname>=' to specify a key (which may be an empty "
-          "string to not sign this apk).")
-    sys.exit(1)
+    (is_apk, is_compressed) = GetApkFileInfo(
+        info.filename, compressed_extension)
+    if not is_apk:
+      continue
+    name = os.path.basename(info.filename)
+    if is_compressed:
+      name = name[:-len(compressed_extension)]
+    if name not in apk_key_map:
+      unknown_apks.append(name)
+
+  assert not unknown_apks, \
+      ("No key specified for:\n  {}\n"
+       "Use '-e <apkname>=' to specify a key (which may be an empty string to "
+       "not sign this apk).".format("\n  ".join(unknown_apks)))
 
 
 def SignApk(data, keyname, pw, platform_api_level, codename_to_api_level_map,
@@ -235,32 +276,23 @@
                        apk_key_map, key_passwords, platform_api_level,
                        codename_to_api_level_map,
                        compressed_extension):
-
-  compressed_apk_extension = None
-  if compressed_extension:
-    compressed_apk_extension = ".apk" + compressed_extension
-
   maxsize = max(
       [len(os.path.basename(i.filename)) for i in input_tf_zip.infolist()
-       if (i.filename.endswith('.apk') or
-           (compressed_apk_extension and
-            i.filename.endswith(compressed_apk_extension)))])
+       if GetApkFileInfo(i.filename, compressed_extension)[0]])
   system_root_image = misc_info.get("system_root_image") == "true"
 
   for info in input_tf_zip.infolist():
-    if info.filename.startswith("IMAGES/"):
+    filename = info.filename
+    if filename.startswith("IMAGES/"):
       continue
 
-    data = input_tf_zip.read(info.filename)
+    data = input_tf_zip.read(filename)
     out_info = copy.copy(info)
+    (is_apk, is_compressed) = GetApkFileInfo(filename, compressed_extension)
 
     # Sign APKs.
-    if (info.filename.endswith(".apk") or
-        (compressed_apk_extension and
-         info.filename.endswith(compressed_apk_extension))):
-      is_compressed = (compressed_extension and
-                       info.filename.endswith(compressed_apk_extension))
-      name = os.path.basename(info.filename)
+    if is_apk:
+      name = os.path.basename(filename)
       if is_compressed:
         name = name[:-len(compressed_extension)]
 
@@ -276,15 +308,15 @@
         common.ZipWriteStr(output_tf_zip, out_info, data)
 
     # System properties.
-    elif info.filename in ("SYSTEM/build.prop",
-                           "VENDOR/build.prop",
-                           "SYSTEM/etc/prop.default",
-                           "BOOT/RAMDISK/prop.default",
-                           "BOOT/RAMDISK/default.prop",  # legacy
-                           "ROOT/default.prop",  # legacy
-                           "RECOVERY/RAMDISK/prop.default",
-                           "RECOVERY/RAMDISK/default.prop"):  # legacy
-      print("Rewriting %s:" % (info.filename,))
+    elif filename in ("SYSTEM/build.prop",
+                      "VENDOR/build.prop",
+                      "SYSTEM/etc/prop.default",
+                      "BOOT/RAMDISK/prop.default",
+                      "BOOT/RAMDISK/default.prop",  # legacy
+                      "ROOT/default.prop",  # legacy
+                      "RECOVERY/RAMDISK/prop.default",
+                      "RECOVERY/RAMDISK/default.prop"):  # legacy
+      print("Rewriting %s:" % (filename,))
       if stat.S_ISLNK(info.external_attr >> 16):
         new_data = data
       else:
@@ -293,20 +325,20 @@
 
     # Replace the certs in *mac_permissions.xml (there could be multiple, such
     # as {system,vendor}/etc/selinux/{plat,nonplat}_mac_permissions.xml).
-    elif info.filename.endswith("mac_permissions.xml"):
-      print("Rewriting %s with new keys." % (info.filename,))
+    elif filename.endswith("mac_permissions.xml"):
+      print("Rewriting %s with new keys." % (filename,))
       new_data = ReplaceCerts(data)
       common.ZipWriteStr(output_tf_zip, out_info, new_data)
 
     # Ask add_img_to_target_files to rebuild the recovery patch if needed.
-    elif info.filename in ("SYSTEM/recovery-from-boot.p",
-                           "SYSTEM/etc/recovery.img",
-                           "SYSTEM/bin/install-recovery.sh"):
+    elif filename in ("SYSTEM/recovery-from-boot.p",
+                      "SYSTEM/etc/recovery.img",
+                      "SYSTEM/bin/install-recovery.sh"):
       OPTIONS.rebuild_recovery = True
 
     # Don't copy OTA keys if we're replacing them.
     elif (OPTIONS.replace_ota_keys and
-          info.filename in (
+          filename in (
               "BOOT/RAMDISK/res/keys",
               "BOOT/RAMDISK/etc/update_engine/update-payload-key.pub.pem",
               "RECOVERY/RAMDISK/res/keys",
@@ -315,22 +347,21 @@
       pass
 
     # Skip META/misc_info.txt since we will write back the new values later.
-    elif info.filename == "META/misc_info.txt":
+    elif filename == "META/misc_info.txt":
       pass
 
     # Skip verity public key if we will replace it.
     elif (OPTIONS.replace_verity_public_key and
-          info.filename in ("BOOT/RAMDISK/verity_key",
-                            "ROOT/verity_key")):
+          filename in ("BOOT/RAMDISK/verity_key",
+                       "ROOT/verity_key")):
       pass
 
     # Skip verity keyid (for system_root_image use) if we will replace it.
-    elif (OPTIONS.replace_verity_keyid and
-          info.filename == "BOOT/cmdline"):
+    elif OPTIONS.replace_verity_keyid and filename == "BOOT/cmdline":
       pass
 
     # Skip the care_map as we will regenerate the system/vendor images.
-    elif info.filename == "META/care_map.txt":
+    elif filename == "META/care_map.txt":
       pass
 
     # A non-APK file; copy it verbatim.
diff --git a/tools/releasetools/test_sign_target_files_apks.py b/tools/releasetools/test_sign_target_files_apks.py
index 26f9e10..71bd259 100644
--- a/tools/releasetools/test_sign_target_files_apks.py
+++ b/tools/releasetools/test_sign_target_files_apks.py
@@ -24,7 +24,8 @@
 import common
 import test_utils
 from sign_target_files_apks import (
-    EditTags, ReplaceCerts, ReplaceVerityKeyId, RewriteProps)
+    CheckAllApksSigned, EditTags, GetApkFileInfo, ReplaceCerts,
+    ReplaceVerityKeyId, RewriteProps)
 
 
 class SignTargetFilesApksTest(unittest.TestCase):
@@ -211,3 +212,50 @@
         cert2_path[:-9] : 'non-existent',
     }
     self.assertEqual(output_xml, ReplaceCerts(input_xml))
+
+  def test_CheckAllApksSigned(self):
+    input_file = common.MakeTempFile(suffix='.zip')
+    with zipfile.ZipFile(input_file, 'w') as input_zip:
+      input_zip.writestr('SYSTEM/app/App1.apk', "App1-content")
+      input_zip.writestr('SYSTEM/app/App2.apk.gz', "App2-content")
+
+    apk_key_map = {
+        'App1.apk' : 'key1',
+        'App2.apk' : 'key2',
+        'App3.apk' : 'key3',
+    }
+    with zipfile.ZipFile(input_file) as input_zip:
+      CheckAllApksSigned(input_zip, apk_key_map, None)
+      CheckAllApksSigned(input_zip, apk_key_map, '.gz')
+
+      # 'App2.apk.gz' won't be considered as an APK.
+      CheckAllApksSigned(input_zip, apk_key_map, None)
+      CheckAllApksSigned(input_zip, apk_key_map, '.xz')
+
+      del apk_key_map['App2.apk']
+      self.assertRaises(
+          AssertionError, CheckAllApksSigned, input_zip, apk_key_map, '.gz')
+
+  def test_GetApkFileInfo(self):
+    (is_apk, is_compressed) = GetApkFileInfo("PRODUCT/apps/Chats.apk", None)
+    self.assertTrue(is_apk)
+    self.assertFalse(is_compressed)
+
+    (is_apk, is_compressed) = GetApkFileInfo("PRODUCT/apps/Chats.dat", None)
+    self.assertFalse(is_apk)
+    self.assertFalse(is_compressed)
+
+  def test_GetApkFileInfo_withCompressedApks(self):
+    (is_apk, is_compressed) = GetApkFileInfo("PRODUCT/apps/Chats.apk.gz", ".gz")
+    self.assertTrue(is_apk)
+    self.assertTrue(is_compressed)
+
+    (is_apk, is_compressed) = GetApkFileInfo("PRODUCT/apps/Chats.apk.gz", ".xz")
+    self.assertFalse(is_apk)
+    self.assertFalse(is_compressed)
+
+    self.assertRaises(
+        AssertionError, GetApkFileInfo, "PRODUCT/apps/Chats.apk", "")
+
+    self.assertRaises(
+        AssertionError, GetApkFileInfo, "PRODUCT/apps/Chats.apk", "apk")