Merge "releasetools: Refactor AddImagesToTargetFiles()."
diff --git a/tools/releasetools/add_img_to_target_files.py b/tools/releasetools/add_img_to_target_files.py
index 9601d88..0628360 100755
--- a/tools/releasetools/add_img_to_target_files.py
+++ b/tools/releasetools/add_img_to_target_files.py
@@ -462,6 +462,125 @@
   img.Write()
 
 
+def AddRadioImagesForAbOta(output_zip, ab_partitions):
+  """Adds the radio images needed for A/B OTA to the output file.
+
+  It parses the list of A/B partitions, looks for missing images in RADIO/ or
+  VENDOR_IMAGES/, and copies them to IMAGES/ of the output file (or dir).
+
+  It also ensures that on returning from the function all the listed A/B
+  partitions must have their images available under IMAGES/.
+
+  Args:
+    output_zip: The output zip file (needs to be already open), or None to
+        write images to OPTIONS.input_tmp/.
+    ab_partitions: The list of A/B partitions. Entries are stripped of
+        surrounding whitespace, and blank entries are skipped.
+
+  Raises:
+    AssertionError: If it can't find an image.
+  """
+  for partition in ab_partitions:
+    partition = partition.strip()
+    if not partition:
+      continue
+    img_name = partition + ".img"
+    prebuilt_path = os.path.join(OPTIONS.input_tmp, "IMAGES", img_name)
+    if os.path.exists(prebuilt_path):
+      print("%s already exists, no need to overwrite..." % (img_name,))
+      continue
+
+    # Prefer RADIO/, then the first match anywhere under VENDOR_IMAGES/.
+    source_path = os.path.join(OPTIONS.input_tmp, "RADIO", img_name)
+    if not os.path.exists(source_path):
+      source_path = None
+      img_vendor_dir = os.path.join(OPTIONS.input_tmp, "VENDOR_IMAGES")
+      for root, _, files in os.walk(img_vendor_dir):
+        if img_name in files:
+          source_path = os.path.join(root, img_name)
+          break
+
+    if source_path and output_zip:
+      common.ZipWrite(output_zip, source_path,
+                      os.path.join("IMAGES", img_name))
+    elif source_path:
+      shutil.copy(source_path, prebuilt_path)
+
+    if output_zip:
+      # Zip spec says: All slashes MUST be forward slashes.
+      img_path = 'IMAGES/' + img_name
+      assert img_path in output_zip.namelist(), "cannot find " + img_name
+    else:
+      assert os.path.exists(prebuilt_path), "cannot find " + img_name
+
+
+def AddCareMapTxtForAbOta(output_zip, ab_partitions, image_paths):
+  """Generates and adds care_map.txt for system and vendor partitions.
+
+  Args:
+    output_zip: The output zip file (needs to be already open), or None to
+        write META/care_map.txt under OPTIONS.input_tmp/.
+    ab_partitions: The list of A/B partitions. Only "system" and "vendor"
+        with verity (or AVB hashtree) enabled get care map entries.
+    image_paths: A map from the partition name to the image path.
+
+  Raises:
+    AssertionError: If the image of such a partition doesn't exist.
+  """
+  care_map_list = []
+  for partition in ab_partitions:
+    partition = partition.strip()
+    verity_enabled = (
+        (partition + "_verity_block_device") in OPTIONS.info_dict or
+        OPTIONS.info_dict.get("avb_%s_hashtree_enable" % partition) == "true")
+    if partition not in ("system", "vendor") or not verity_enabled:
+      continue
+    img_path = image_paths[partition]
+    assert os.path.exists(img_path), "cannot find " + img_path
+    care_map_list += GetCareMap(partition, img_path)
+
+  if care_map_list:
+    care_map_path = "META/care_map.txt"
+    if output_zip and care_map_path not in output_zip.namelist():
+      common.ZipWriteStr(output_zip, care_map_path, '\n'.join(care_map_list))
+    else:
+      with open(os.path.join(OPTIONS.input_tmp, care_map_path), 'w') as fp:
+        fp.write('\n'.join(care_map_list))
+      if output_zip:
+        OPTIONS.replace_updated_files_list.append(care_map_path)
+
+
+def AddPackRadioImages(output_zip, images):
+  """Copies images listed in META/pack_radioimages.txt from RADIO/ to IMAGES/.
+
+  Args:
+    output_zip: The output zip file (needs to be already open), or None to
+        write images to OPTIONS.input_tmp/.
+    images: A list of image names (".img" if no extension); blanks skipped.
+
+  Raises:
+    AssertionError: If a listed image can't be found.
+  """
+  for img_name in (image.strip() for image in images):
+    if not img_name:
+      continue
+    if not os.path.splitext(img_name)[1]:
+      img_name += ".img"
+    prebuilt_path = os.path.join(OPTIONS.input_tmp, "IMAGES", img_name)
+    if os.path.exists(prebuilt_path):
+      print("%s already exists, no need to overwrite..." % (img_name,))
+      continue
+
+    img_radio_path = os.path.join(OPTIONS.input_tmp, "RADIO", img_name)
+    assert os.path.exists(img_radio_path), \
+        "Failed to find %s at %s" % (img_name, img_radio_path)
+    if output_zip:
+      common.ZipWrite(output_zip, img_radio_path,
+                      os.path.join("IMAGES", img_name))
+    else:
+      shutil.copy(img_radio_path, prebuilt_path)
+
+
 def ReplaceUpdatedFiles(zip_filename, files_list):
   """Updates all the ZIP entries listed in files_list.
 
@@ -589,12 +708,12 @@
           recovery_two_step_image.AddToZip(output_zip)
 
   banner("system")
-  partitions['system'] = system_img_path = AddSystem(
+  partitions['system'] = AddSystem(
       output_zip, recovery_img=recovery_image, boot_img=boot_image)
 
   if has_vendor:
     banner("vendor")
-    partitions['vendor'] = vendor_img_path = AddVendor(output_zip)
+    partitions['vendor'] = AddVendor(output_zip)
 
   if has_system_other:
     banner("system_other")
@@ -618,95 +737,28 @@
     banner("vbmeta")
     AddVBMeta(output_zip, partitions)
 
-  # For devices using A/B update, copy over images from RADIO/ and/or
-  # VENDOR_IMAGES/ to IMAGES/ and make sure we have all the needed
-  # images ready under IMAGES/. All images should have '.img' as extension.
   banner("radio")
-  ab_partitions = os.path.join(OPTIONS.input_tmp, "META", "ab_partitions.txt")
-  if os.path.exists(ab_partitions):
-    with open(ab_partitions, 'r') as f:
-      lines = f.readlines()
-    # For devices using A/B update, generate care_map for system and vendor
-    # partitions (if present), then write this file to target_files package.
-    care_map_list = []
-    for line in lines:
-      if line.strip() == "system" and (
-          "system_verity_block_device" in OPTIONS.info_dict or
-          OPTIONS.info_dict.get("avb_system_hashtree_enable") == "true"):
-        assert os.path.exists(system_img_path)
-        care_map_list += GetCareMap("system", system_img_path)
-      if line.strip() == "vendor" and (
-          "vendor_verity_block_device" in OPTIONS.info_dict or
-          OPTIONS.info_dict.get("avb_vendor_hashtree_enable") == "true"):
-        assert os.path.exists(vendor_img_path)
-        care_map_list += GetCareMap("vendor", vendor_img_path)
+  ab_partitions_txt = os.path.join(OPTIONS.input_tmp, "META",
+                                   "ab_partitions.txt")
+  if os.path.exists(ab_partitions_txt):
+    with open(ab_partitions_txt, 'r') as f:
+      ab_partitions = f.readlines()
 
-      img_name = line.strip() + ".img"
-      prebuilt_path = os.path.join(OPTIONS.input_tmp, "IMAGES", img_name)
-      if os.path.exists(prebuilt_path):
-        print("%s already exists, no need to overwrite..." % (img_name,))
-        continue
+    # For devices using A/B update, copy over images from RADIO/ and/or
+    # VENDOR_IMAGES/ to IMAGES/ and make sure we have all the needed
+    # images ready under IMAGES/. All images should have '.img' as extension.
+    AddRadioImagesForAbOta(output_zip, ab_partitions)
 
-      img_radio_path = os.path.join(OPTIONS.input_tmp, "RADIO", img_name)
-      if os.path.exists(img_radio_path):
-        if output_zip:
-          common.ZipWrite(output_zip, img_radio_path,
-                          os.path.join("IMAGES", img_name))
-        else:
-          shutil.copy(img_radio_path, prebuilt_path)
-      else:
-        img_vendor_dir = os.path.join(OPTIONS.input_tmp, "VENDOR_IMAGES")
-        for root, _, files in os.walk(img_vendor_dir):
-          if img_name in files:
-            if output_zip:
-              common.ZipWrite(output_zip, os.path.join(root, img_name),
-                              os.path.join("IMAGES", img_name))
-            else:
-              shutil.copy(os.path.join(root, img_name), prebuilt_path)
-            break
-
-      if output_zip:
-        # Zip spec says: All slashes MUST be forward slashes.
-        img_path = 'IMAGES/' + img_name
-        assert img_path in output_zip.namelist(), "cannot find " + img_name
-      else:
-        img_path = os.path.join(OPTIONS.input_tmp, "IMAGES", img_name)
-        assert os.path.exists(img_path), "cannot find " + img_name
-
-    if care_map_list:
-      care_map_path = "META/care_map.txt"
-      if output_zip and care_map_path not in output_zip.namelist():
-        common.ZipWriteStr(output_zip, care_map_path, '\n'.join(care_map_list))
-      else:
-        with open(os.path.join(OPTIONS.input_tmp, care_map_path), 'w') as fp:
-          fp.write('\n'.join(care_map_list))
-        if output_zip:
-          OPTIONS.replace_updated_files_list.append(care_map_path)
+    # Generate care_map.txt for system and vendor partitions (if present), then
+    # write this file to target_files package.
+    AddCareMapTxtForAbOta(output_zip, ab_partitions, partitions)
 
   # Radio images that need to be packed into IMAGES/, and product-img.zip.
-  pack_radioimages = os.path.join(
+  pack_radioimages_txt = os.path.join(
       OPTIONS.input_tmp, "META", "pack_radioimages.txt")
-  if os.path.exists(pack_radioimages):
-    with open(pack_radioimages, 'r') as f:
-      lines = f.readlines()
-    for line in lines:
-      img_name = line.strip()
-      _, ext = os.path.splitext(img_name)
-      if not ext:
-        img_name += ".img"
-      prebuilt_path = os.path.join(OPTIONS.input_tmp, "IMAGES", img_name)
-      if os.path.exists(prebuilt_path):
-        print("%s already exists, no need to overwrite..." % (img_name,))
-        continue
-
-      img_radio_path = os.path.join(OPTIONS.input_tmp, "RADIO", img_name)
-      assert os.path.exists(img_radio_path), \
-          "Failed to find %s at %s" % (img_name, img_radio_path)
-      if output_zip:
-        common.ZipWrite(output_zip, img_radio_path,
-                        os.path.join("IMAGES", img_name))
-      else:
-        shutil.copy(img_radio_path, prebuilt_path)
+  if os.path.exists(pack_radioimages_txt):
+    with open(pack_radioimages_txt, 'r') as f:
+      AddPackRadioImages(output_zip, f.readlines())
 
   if output_zip:
     common.ZipClose(output_zip)
diff --git a/tools/releasetools/test_add_img_to_target_files.py b/tools/releasetools/test_add_img_to_target_files.py
new file mode 100644
index 0000000..e449ca8
--- /dev/null
+++ b/tools/releasetools/test_add_img_to_target_files.py
@@ -0,0 +1,168 @@
+#
+# Copyright (C) 2018 The Android Open Source Project
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+import os.path
+import unittest
+import zipfile
+
+import common
+from add_img_to_target_files import AddPackRadioImages, AddRadioImagesForAbOta
+
+
+OPTIONS = common.OPTIONS
+
+
+class AddImagesToTargetFilesTest(unittest.TestCase):
+
+  def setUp(self):
+    OPTIONS.input_tmp = common.MakeTempDir()
+
+  def tearDown(self):
+    common.Cleanup()
+
+  @staticmethod
+  def _create_images(images, prefix):
+    """Creates <image>.img files under prefix/ and ensures IMAGES/ exists."""
+    path = os.path.join(OPTIONS.input_tmp, prefix)
+    if not os.path.exists(path):
+      os.mkdir(path)
+
+    for image in images:
+      image_path = os.path.join(path, image + '.img')
+      with open(image_path, 'wb') as image_fp:
+        image_fp.write(image.encode())
+
+    images_path = os.path.join(OPTIONS.input_tmp, 'IMAGES')
+    if not os.path.exists(images_path):
+      os.mkdir(images_path)
+    return images, images_path
+
+  def test_AddRadioImagesForAbOta_imageExists(self):
+    """Tests the case with existing images under IMAGES/."""
+    images, images_path = self._create_images(['aboot', 'xbl'], 'IMAGES')
+    AddRadioImagesForAbOta(None, images)
+
+    for image in images:
+      self.assertTrue(
+          os.path.exists(os.path.join(images_path, image + '.img')))
+
+  def test_AddRadioImagesForAbOta_copyFromRadio(self):
+    """Tests the case that copies images from RADIO/."""
+    images, images_path = self._create_images(['aboot', 'xbl'], 'RADIO')
+    AddRadioImagesForAbOta(None, images)
+
+    for image in images:
+      with open(os.path.join(images_path, image + '.img'), 'rb') as image_fp:
+        self.assertEqual(image.encode(), image_fp.read())
+
+  def test_AddRadioImagesForAbOta_copyFromRadio_zipOutput(self):
+    """Tests the case that copies images from RADIO/ into the output zip."""
+    images, _ = self._create_images(['aboot', 'xbl'], 'RADIO')
+    # Set up the output zip.
+    output_file = common.MakeTempFile(suffix='.zip')
+    with zipfile.ZipFile(output_file, 'w') as output_zip:
+      AddRadioImagesForAbOta(output_zip, images)
+
+    with zipfile.ZipFile(output_file, 'r') as verify_zip:
+      for image in images:
+        self.assertIn('IMAGES/' + image + '.img', verify_zip.namelist())
+
+  def test_AddRadioImagesForAbOta_copyFromVendorImages(self):
+    """Tests the case that copies images from VENDOR_IMAGES/."""
+    vendor_images_path = os.path.join(OPTIONS.input_tmp, 'VENDOR_IMAGES')
+    os.mkdir(vendor_images_path)
+
+    partitions = ['aboot', 'xbl']
+    for index, partition in enumerate(partitions):
+      subdir = os.path.join(vendor_images_path, 'subdir-{}'.format(index))
+      os.mkdir(subdir)
+
+      partition_image_path = os.path.join(subdir, partition + '.img')
+      with open(partition_image_path, 'wb') as partition_fp:
+        partition_fp.write(partition.encode())
+
+    # Set up the output dir.
+    images_path = os.path.join(OPTIONS.input_tmp, 'IMAGES')
+    os.mkdir(images_path)
+
+    AddRadioImagesForAbOta(None, partitions)
+
+    for partition in partitions:
+      self.assertTrue(
+          os.path.exists(os.path.join(images_path, partition + '.img')))
+
+  def test_AddRadioImagesForAbOta_missingImages(self):
+    images, _ = self._create_images(['aboot', 'xbl'], 'RADIO')
+    self.assertRaises(AssertionError, AddRadioImagesForAbOta, None,
+                      images + ['baz'])
+
+  def test_AddRadioImagesForAbOta_missingImages_zipOutput(self):
+    """Tests that a missing image also fails when writing to a zip."""
+    images, _ = self._create_images(['aboot', 'xbl'], 'RADIO')
+    # Set up the output zip.
+    output_file = common.MakeTempFile(suffix='.zip')
+    with zipfile.ZipFile(output_file, 'w') as output_zip:
+      self.assertRaises(AssertionError, AddRadioImagesForAbOta, output_zip,
+                        images + ['baz'])
+
+  def test_AddPackRadioImages(self):
+    """Tests copying images from RADIO/ into IMAGES/ (dir output)."""
+    images, images_path = self._create_images(['foo', 'bar'], 'RADIO')
+    AddPackRadioImages(None, images)
+    for image in images:
+      with open(os.path.join(images_path, image + '.img'), 'rb') as image_fp:
+        self.assertEqual(image.encode(), image_fp.read())
+
+  def test_AddPackRadioImages_with_suffix(self):
+    """Tests that names already carrying an extension are used as-is."""
+    images, images_path = self._create_images(['foo', 'bar'], 'RADIO')
+    images_with_suffix = [image + '.img' for image in images]
+    AddPackRadioImages(None, images_with_suffix)
+    for image in images:
+      self.assertTrue(
+          os.path.exists(os.path.join(images_path, image + '.img')))
+
+  def test_AddPackRadioImages_zipOutput(self):
+    """Tests copying images from RADIO/ into the output zip."""
+    images, _ = self._create_images(['foo', 'bar'], 'RADIO')
+    # Set up the output zip.
+    output_file = common.MakeTempFile(suffix='.zip')
+    with zipfile.ZipFile(output_file, 'w') as output_zip:
+      AddPackRadioImages(output_zip, images)
+
+    with zipfile.ZipFile(output_file, 'r') as verify_zip:
+      for image in images:
+        self.assertIn('IMAGES/' + image + '.img', verify_zip.namelist())
+
+  def test_AddPackRadioImages_imageExists(self):
+    """Tests that existing images under IMAGES/ are not overwritten."""
+    images, images_path = self._create_images(['foo', 'bar'], 'RADIO')
+    # Pre-populate IMAGES/ with distinct content that must be preserved.
+    for image in images:
+      with open(os.path.join(images_path, image + '.img'), 'w') as image_fp:
+        image_fp.write('existing')
+    AddPackRadioImages(None, images)
+    for image in images:
+      with open(os.path.join(images_path, image + '.img')) as image_fp:
+        self.assertEqual('existing', image_fp.read())
+
+  def test_AddPackRadioImages_missingImages(self):
+    """Tests that an image missing from RADIO/ raises AssertionError."""
+    images, _ = self._create_images(['foo', 'bar'], 'RADIO')
+    AddPackRadioImages(None, images)
+    self.assertRaises(AssertionError, AddPackRadioImages, None,
+                      images + ['baz'])