update_payload: Allow specifying partition options for major version 2

This commit adds the ability to specify partition options for partitions
other than just kernel and rootfs.

This change deprecates the flags -p/--root-part-size, -P/--kern-part-size,
--dst_kern, --dst_root, --src_kern, --src_root, --out_dst_kern, and
--out_dst_root.

They are replaced by --part_names used in conjunction with --part_sizes,
--dst_part_paths, --src_part_paths, and --out_dst_part_paths. Each of these
takes one value per partition, in the same order as --part_names.

Backward compatibility with the old flags is preserved, as long as they are
not mixed with the new flags.

BUG=b:794404
TEST=no errors during run_unittests and test_paycheck.sh

Change-Id: Icc1118abbf89dd268be3eafe41723657c5178197
Reviewed-on: https://chromium-review.googlesource.com/1103063
Commit-Ready: Tudor Brindus <tbrindus@chromium.org>
Tested-by: Tudor Brindus <tbrindus@chromium.org>
Reviewed-by: Amin Hassani <ahassani@chromium.org>
diff --git a/scripts/paycheck.py b/scripts/paycheck.py
index 0050f5b..9d61778 100755
--- a/scripts/paycheck.py
+++ b/scripts/paycheck.py
@@ -26,6 +26,7 @@
 import sys
 import tempfile
 
+from update_payload import common
 from update_payload import error
 
 lib_dir = os.path.join(os.path.dirname(__file__), 'lib')
@@ -92,12 +93,15 @@
   check_args.add_argument('-s', '--metadata-size', metavar='NUM', default=0,
                           help='the metadata size to verify with the one in'
                           ' payload')
+  # TODO(tbrindus): deprecated in favour of --part_sizes
   check_args.add_argument('-p', '--root-part-size', metavar='NUM',
                           default=0, type=int,
                           help='override rootfs partition size auto-inference')
   check_args.add_argument('-P', '--kern-part-size', metavar='NUM',
                           default=0, type=int,
                           help='override kernel partition size auto-inference')
+  check_args.add_argument('--part_sizes', metavar='NUM', nargs='+', type=int,
+                          help='override partition size auto-inference')
 
   apply_args = parser.add_argument_group('Applying payload')
   # TODO(ahassani): Extent extract-bsdiff to puffdiff too.
@@ -109,42 +113,66 @@
                           help='use the specified bspatch binary')
   apply_args.add_argument('--puffpatch-path', metavar='FILE',
                           help='use the specified puffpatch binary')
+  # TODO(tbrindus): deprecated in favour of --dst_part_paths
   apply_args.add_argument('--dst_kern', metavar='FILE',
                           help='destination kernel partition file')
   apply_args.add_argument('--dst_root', metavar='FILE',
                           help='destination root partition file')
+  # TODO(tbrindus): deprecated in favour of --src_part_paths
   apply_args.add_argument('--src_kern', metavar='FILE',
                           help='source kernel partition file')
   apply_args.add_argument('--src_root', metavar='FILE',
                           help='source root partition file')
+  # TODO(tbrindus): deprecated in favour of --out_dst_part_paths
   apply_args.add_argument('--out_dst_kern', metavar='FILE',
                           help='created destination kernel partition file')
   apply_args.add_argument('--out_dst_root', metavar='FILE',
                           help='created destination root partition file')
 
+  apply_args.add_argument('--src_part_paths', metavar='FILE', nargs='+',
+                          help='source partitition files')
+  apply_args.add_argument('--dst_part_paths', metavar='FILE', nargs='+',
+                          help='destination partition files')
+  apply_args.add_argument('--out_dst_part_paths', metavar='FILE', nargs='+',
+                          help='created destination partition files')
+
   parser.add_argument('payload', metavar='PAYLOAD', help='the payload file')
+  parser.add_argument('--part_names', metavar='NAME', nargs='+',
+                      help='names of partitions')
 
   # Parse command-line arguments.
   args = parser.parse_args(argv)
 
+  # TODO(tbrindus): temporary workaround to keep old-style flags from breaking
+  # without having to handle both types in our code. Remove after flag usage is
+  # removed from calling scripts.
+  args.part_names = args.part_names or [common.KERNEL, common.ROOTFS]
+  args.part_sizes = args.part_sizes or [args.kern_part_size,
+                                        args.root_part_size]
+  args.src_part_paths = args.src_part_paths or [args.src_kern, args.src_root]
+  args.dst_part_paths = args.dst_part_paths or [args.dst_kern, args.dst_root]
+  args.out_dst_part_paths = args.out_dst_part_paths or [args.out_dst_kern,
+                                                        args.out_dst_root]
+
+  # Make sure we don't have new dependencies on old flags by deleting them from
+  # the namespace here.
+  for old in ['kern_part_size', 'root_part_size', 'src_kern', 'src_root',
+              'dst_kern', 'dst_root', 'out_dst_kern', 'out_dst_root']:
+    delattr(args, old)
+
   # There are several options that imply --check.
   args.check = (args.check or args.report or args.assert_type or
                 args.block_size or args.allow_unhashed or
                 args.disabled_tests or args.meta_sig or args.key or
-                args.root_part_size or args.kern_part_size or
-                args.metadata_size)
+                any(args.part_sizes) or args.metadata_size)
 
-  # Check the arguments, enforce payload type accordingly.
-  if (args.src_kern is None) != (args.src_root is None):
-    parser.error('--src_kern and --src_root should be given together')
-  if (args.dst_kern is None) != (args.dst_root is None):
-    parser.error('--dst_kern and --dst_root should be given together')
-  if (args.out_dst_kern is None) != (args.out_dst_root is None):
-    parser.error('--out_dst_kern and --out_dst_root should be given together')
+  for arg in ['part_sizes', 'src_part_paths', 'dst_part_paths',
+              'out_dst_part_paths']:
+    if len(args.part_names) != len(getattr(args, arg, [])):
+      parser.error('partitions in --%s do not match --part_names' % arg)
 
-  if (args.dst_kern and args.dst_root) or \
-     (args.out_dst_kern and args.out_dst_root):
-    if args.src_kern and args.src_root:
+  if all(args.dst_part_paths) or all(args.out_dst_part_paths):
+    if all(args.src_part_paths):
       if args.assert_type == _TYPE_FULL:
         parser.error('%s payload does not accept source partition arguments'
                      % _TYPE_FULL)
@@ -202,6 +230,7 @@
               report_file = open(args.report, 'w')
               do_close_report_file = True
 
+          part_sizes = dict(zip(args.part_names, args.part_sizes))
           metadata_sig_file = args.meta_sig and open(args.meta_sig)
           payload.Check(
               pubkey_file_name=args.key,
@@ -210,8 +239,7 @@
               report_out_file=report_file,
               assert_type=args.assert_type,
               block_size=int(args.block_size),
-              rootfs_part_size=args.root_part_size,
-              kernel_part_size=args.kern_part_size,
+              part_sizes=part_sizes,
               allow_unhashed=args.allow_unhashed,
               disabled_tests=args.disabled_tests)
         finally:
@@ -221,46 +249,50 @@
             report_file.close()
 
       # Apply payload.
-      if (args.dst_root and args.dst_kern) or \
-         (args.out_dst_root and args.out_dst_kern):
+      if all(args.dst_part_paths) or all(args.out_dst_part_paths):
         dargs = {'bsdiff_in_place': not args.extract_bsdiff}
         if args.bspatch_path:
           dargs['bspatch_path'] = args.bspatch_path
         if args.puffpatch_path:
           dargs['puffpatch_path'] = args.puffpatch_path
         if args.assert_type == _TYPE_DELTA:
-          dargs['old_kernel_part'] = args.src_kern
-          dargs['old_rootfs_part'] = args.src_root
+          dargs['old_parts'] = dict(zip(args.part_names, args.src_part_paths))
 
-        if args.out_dst_kern and args.out_dst_root:
-          out_dst_kern = open(args.out_dst_kern, 'w+')
-          out_dst_root = open(args.out_dst_root, 'w+')
+        out_dst_parts = {}
+        file_handles = []
+        if all(args.out_dst_part_paths):
+          for name, path in zip(args.part_names, args.out_dst_part_paths):
+            handle = open(path, 'w+')
+            file_handles.append(handle)
+            out_dst_parts[name] = handle.name
         else:
-          out_dst_kern = tempfile.NamedTemporaryFile()
-          out_dst_root = tempfile.NamedTemporaryFile()
+          for name in args.part_names:
+            handle = tempfile.NamedTemporaryFile()
+            file_handles.append(handle)
+            out_dst_parts[name] = handle.name
 
-        payload.Apply(out_dst_kern.name, out_dst_root.name, **dargs)
+        payload.Apply(out_dst_parts, **dargs)
 
         # If destination kernel and rootfs partitions are not given, then this
         # just becomes an apply operation with no check.
-        if args.dst_kern and args.dst_root:
+        if all(args.dst_part_paths):
           # Prior to comparing, add the unused space past the filesystem
           # boundary in the new target partitions to become the same size as
           # the given partitions. This will truncate to larger size.
-          out_dst_kern.truncate(os.path.getsize(args.dst_kern))
-          out_dst_root.truncate(os.path.getsize(args.dst_root))
+          for part_name, out_dst_part, dst_part in zip(args.part_names,
+                                                       file_handles,
+                                                       args.dst_part_paths):
+            out_dst_part.truncate(os.path.getsize(dst_part))
 
-          # Compare resulting partitions with the ones from the target image.
-          if not filecmp.cmp(out_dst_kern.name, args.dst_kern):
-            raise error.PayloadError('Resulting kernel partition corrupted.')
-          if not filecmp.cmp(out_dst_root.name, args.dst_root):
-            raise error.PayloadError('Resulting rootfs partition corrupted.')
+            # Compare resulting partitions with the ones from the target image.
+            if not filecmp.cmp(out_dst_part.name, dst_part):
+              raise error.PayloadError(
+                  'Resulting %s partition corrupted.' % part_name)
 
         # Close the output files. If args.out_dst_* was not given, then these
         # files are created as temp files and will be deleted upon close().
-        out_dst_kern.close()
-        out_dst_root.close()
-
+        for handle in file_handles:
+          handle.close()
     except error.PayloadError, e:
       sys.stderr.write('Error: %s\n' % e)
       return 1
diff --git a/scripts/update_payload/applier.py b/scripts/update_payload/applier.py
index 9582b3d..dad5ba3 100644
--- a/scripts/update_payload/applier.py
+++ b/scripts/update_payload/applier.py
@@ -622,21 +622,24 @@
       _VerifySha256(new_part_file, new_part_info.hash,
                     'new ' + part_name, length=new_part_info.size)
 
-  def Run(self, new_kernel_part, new_rootfs_part, old_kernel_part=None,
-          old_rootfs_part=None):
+  def Run(self, new_parts, old_parts=None):
     """Applier entry point, invoking all update operations.
 
     Args:
-      new_kernel_part: name of dest kernel partition file
-      new_rootfs_part: name of dest rootfs partition file
-      old_kernel_part: name of source kernel partition file (optional)
-      old_rootfs_part: name of source rootfs partition file (optional)
+      new_parts: map of partition name to dest partition file
+      old_parts: map of partition name to source partition file (optional)
 
     Raises:
       PayloadError if payload application failed.
     """
     self.payload.ResetFile()
 
+    # TODO(tbrindus): make payload applying work for major version 2 partitions
+    new_kernel_part = new_parts[common.KERNEL]
+    new_rootfs_part = new_parts[common.ROOTFS]
+    old_kernel_part = old_parts.get(common.KERNEL, None) if old_parts else None
+    old_rootfs_part = old_parts.get(common.ROOTFS, None) if old_parts else None
+
     # Make sure the arguments are sane and match the payload.
     if not (new_kernel_part and new_rootfs_part):
       raise PayloadError('missing dst {kernel,rootfs} partitions')
diff --git a/scripts/update_payload/checker.py b/scripts/update_payload/checker.py
index 49c556f..ec8810d 100644
--- a/scripts/update_payload/checker.py
+++ b/scripts/update_payload/checker.py
@@ -561,8 +561,7 @@
     Raises:
       error.PayloadError if any of the checks fail.
     """
-    if part_sizes is None:
-      part_sizes = collections.defaultdict(int)
+    part_sizes = collections.defaultdict(int, part_sizes)
 
     manifest = self.payload.manifest
     report.AddSection('manifest')
@@ -1226,17 +1225,15 @@
                                  sig.version)
 
   def Run(self, pubkey_file_name=None, metadata_sig_file=None, metadata_size=0,
-          rootfs_part_size=0, kernel_part_size=0, report_out_file=None):
+          part_sizes=None, report_out_file=None):
     """Checker entry point, invoking all checks.
 
     Args:
       pubkey_file_name: Public key used for signature verification.
       metadata_sig_file: Metadata signature, if verification is desired.
-      metadata_size: metadata size, if verification is desired
-      rootfs_part_size: The size of rootfs partitions in bytes (default: infer
-                        based on payload type and version).
-      kernel_part_size: The size of kernel partitions in bytes (default: use
-                        reported filesystem size).
+      metadata_size: Metadata size, if verification is desired.
+      part_sizes: Mapping of partition label to size in bytes (default: infer
+        based on payload type and version or filesystem).
       report_out_file: File object to dump the report to.
 
     Raises:
@@ -1276,10 +1273,7 @@
       report.AddField('manifest len', self.payload.header.manifest_len)
 
       # Part 2: Check the manifest.
-      self._CheckManifest(report, {
-          common.ROOTFS: rootfs_part_size,
-          common.KERNEL: kernel_part_size
-      })
+      self._CheckManifest(report, part_sizes)
       assert self.payload_type, 'payload type should be known by now'
 
       # Infer the usable partition size when validating rootfs operations:
@@ -1290,9 +1284,9 @@
       # - Otherwise, use the encoded filesystem size.
       new_rootfs_usable_size = self.new_fs_sizes[common.ROOTFS]
       old_rootfs_usable_size = self.old_fs_sizes[common.ROOTFS]
-      if rootfs_part_size:
-        new_rootfs_usable_size = rootfs_part_size
-        old_rootfs_usable_size = rootfs_part_size
+      if part_sizes.get(common.ROOTFS, 0):
+        new_rootfs_usable_size = part_sizes[common.ROOTFS]
+        old_rootfs_usable_size = part_sizes[common.ROOTFS]
       elif self.payload_type == _TYPE_DELTA and self.minor_version in (None, 1):
         new_rootfs_usable_size = _OLD_DELTA_USABLE_PART_SIZE
         old_rootfs_usable_size = _OLD_DELTA_USABLE_PART_SIZE
@@ -1313,6 +1307,7 @@
       report.AddSection('kernel operations')
       old_kernel_fs_size = self.old_fs_sizes[common.KERNEL]
       new_kernel_fs_size = self.new_fs_sizes[common.KERNEL]
+      kernel_part_size = part_sizes.get(common.KERNEL, None)
       total_blob_size += self._CheckOperations(
           self.payload.manifest.kernel_install_operations, report,
           'kernel_install_operations', old_kernel_fs_size, new_kernel_fs_size,
diff --git a/scripts/update_payload/checker_unittest.py b/scripts/update_payload/checker_unittest.py
index d42f4b4..ed5ee80 100755
--- a/scripts/update_payload/checker_unittest.py
+++ b/scripts/update_payload/checker_unittest.py
@@ -1205,10 +1205,13 @@
       payload_checker = _GetPayloadChecker(payload_gen.WriteToFileWithData,
                                            **kwargs)
 
-      kwargs = {'pubkey_file_name': test_utils._PUBKEY_FILE_NAME,
-                'rootfs_part_size': rootfs_part_size,
-                'metadata_size': metadata_size,
-                'kernel_part_size': kernel_part_size}
+      kwargs = {
+          'pubkey_file_name': test_utils._PUBKEY_FILE_NAME,
+          'metadata_size': metadata_size,
+          'part_sizes': {
+              common.KERNEL: kernel_part_size,
+              common.ROOTFS: rootfs_part_size}}
+
       should_fail = (fail_wrong_payload_type or fail_mismatched_block_size or
                      fail_mismatched_metadata_size or fail_excess_data or
                      fail_rootfs_part_size_exceeded or
diff --git a/scripts/update_payload/payload.py b/scripts/update_payload/payload.py
index 15f66d0..2a0cb58 100644
--- a/scripts/update_payload/payload.py
+++ b/scripts/update_payload/payload.py
@@ -274,8 +274,8 @@
 
   def Check(self, pubkey_file_name=None, metadata_sig_file=None,
             metadata_size=0, report_out_file=None, assert_type=None,
-            block_size=0, rootfs_part_size=0, kernel_part_size=0,
-            allow_unhashed=False, disabled_tests=()):
+            block_size=0, part_sizes=None, allow_unhashed=False,
+            disabled_tests=()):
     """Checks the payload integrity.
 
     Args:
@@ -285,8 +285,7 @@
       report_out_file: file object to dump the report to
       assert_type: assert that payload is either 'full' or 'delta'
       block_size: expected filesystem / payload block size
-      rootfs_part_size: the size of (physical) rootfs partitions in bytes
-      kernel_part_size: the size of (physical) kernel partitions in bytes
+      part_sizes: map of partition label to (physical) size in bytes
       allow_unhashed: allow unhashed operation blobs
       disabled_tests: list of tests to disable
 
@@ -302,20 +301,17 @@
     helper.Run(pubkey_file_name=pubkey_file_name,
                metadata_sig_file=metadata_sig_file,
                metadata_size=metadata_size,
-               rootfs_part_size=rootfs_part_size,
-               kernel_part_size=kernel_part_size,
+               part_sizes=part_sizes,
                report_out_file=report_out_file)
 
-  def Apply(self, new_kernel_part, new_rootfs_part, old_kernel_part=None,
-            old_rootfs_part=None, bsdiff_in_place=True, bspatch_path=None,
-            puffpatch_path=None, truncate_to_expected_size=True):
+  def Apply(self, new_parts, old_parts=None, bsdiff_in_place=True,
+            bspatch_path=None, puffpatch_path=None,
+            truncate_to_expected_size=True):
     """Applies the update payload.
 
     Args:
-      new_kernel_part: name of dest kernel partition file
-      new_rootfs_part: name of dest rootfs partition file
-      old_kernel_part: name of source kernel partition file (optional)
-      old_rootfs_part: name of source rootfs partition file (optional)
+      new_parts: map of partition name to dest partition file
+      old_parts: map of partition name to partition file (optional)
       bsdiff_in_place: whether to perform BSDIFF operations in-place (optional)
       bspatch_path: path to the bspatch binary (optional)
       puffpatch_path: path to the puffpatch binary (optional)
@@ -333,6 +329,4 @@
         self, bsdiff_in_place=bsdiff_in_place, bspatch_path=bspatch_path,
         puffpatch_path=puffpatch_path,
         truncate_to_expected_size=truncate_to_expected_size)
-    helper.Run(new_kernel_part, new_rootfs_part,
-               old_kernel_part=old_kernel_part,
-               old_rootfs_part=old_rootfs_part)
+    helper.Run(new_parts, old_parts=old_parts)