Merge "[release] Add more check for rialto resigning" into main
diff --git a/apex/sign_virt_apex.py b/apex/sign_virt_apex.py
index 0bcbac7..0b6137b 100644
--- a/apex/sign_virt_apex.py
+++ b/apex/sign_virt_apex.py
@@ -208,74 +208,111 @@
     return info, descriptors
 
 
-# Look up a list of (key, value) with a key. Returns the list of value(s) with the matching key.
-# The order of those values is maintained.
-def LookUp(pairs, key):
+def find_all_values_by_key(pairs, key):
+    """Find all the values of the key in the pairs."""
     return [v for (k, v) in pairs if k == key]
 
 # Extract properties from the descriptors of the original vbmeta image and
 # append them to the command as parameters.
 def AppendPropArgument(cmd, descriptors):
-    for prop in LookUp(descriptors, 'Prop'):
+    for prop in find_all_values_by_key(descriptors, 'Prop'):
         cmd.append('--prop')
         result = re.match(r"(.+) -> '(.+)'", prop)
         cmd.append(result.group(1) + ":" + result.group(2))
 
 
-def check_no_size_change_on_resigned_image(image_path, original_image_info, resigned_image_info):
-    assert original_image_info is not None, f'no avbinfo on original image: {image_path}'
-    assert resigned_image_info is not None, f'no avbinfo on resigned image: {image_path}'
-    assert_same_value(original_image_info, resigned_image_info, "Header Block", image_path)
-    assert_same_value(original_image_info, resigned_image_info, "Authentication Block", image_path)
-    assert_same_value(original_image_info, resigned_image_info, "Auxiliary Block", image_path)
+def check_resigned_image_avb_info(image_path, original_info, original_descriptors, args):
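+    """Check the AVB info of the resigned image at image_path against the original.
+
+    The public key must change (and must be the same across all resigned images),
+    while every other avbinfo field and descriptor parameter must be preserved.
+    """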
+    updated_info, updated_descriptors = AvbInfo(args, image_path)
+    assert original_info is not None, f'no avbinfo on original image: {image_path}'
+    assert updated_info is not None, f'no avbinfo on resigned image: {image_path}'
+    assert_different_value(original_info, updated_info, "Public key (sha1)", image_path)
+    updated_public_key = updated_info.pop("Public key (sha1)")
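+    # Remember the first resigned public key in a function attribute and require
+    # that every subsequent image is resigned with the same key.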
+    if not hasattr(check_resigned_image_avb_info, "new_public_key"):
+        check_resigned_image_avb_info.new_public_key = updated_public_key
+    else:
+        assert check_resigned_image_avb_info.new_public_key == updated_public_key, \
+            "All images should be resigned with the same public key. Expected public key (sha1):" \
+            f" {check_resigned_image_avb_info.new_public_key}, actual public key (sha1): " \
+            f"{updated_public_key}, Path: {image_path}"
+    original_info.pop("Public key (sha1)")
+    assert original_info == updated_info, \
+        f"Original info and updated info should be the same for {image_path}. " \
+        f"Original info: {original_info}, updated info: {updated_info}"
 
-def AddHashFooter(args, key, image_path, additional_images=None):
+    # Verify the descriptors of the original and updated images.
+    assert len(original_descriptors) == len(updated_descriptors), \
+        f"Number of descriptors should be the same for {image_path}. " \
+        f"Original descriptors: {original_descriptors}, updated descriptors: {updated_descriptors}"
+    original_prop_descriptors = sorted(find_all_values_by_key(original_descriptors, "Prop"))
+    updated_prop_descriptors = sorted(find_all_values_by_key(updated_descriptors, "Prop"))
+    assert original_prop_descriptors == updated_prop_descriptors, \
+        f"Prop descriptors should be the same for {image_path}. " \
+        f"Original prop descriptors: {original_prop_descriptors}, " \
+        f"updated prop descriptors: {updated_prop_descriptors}"
+
+    # Remove digest from hash descriptors before comparing, since some digests should change.
+    original_hash_descriptors = extract_hash_descriptors(original_descriptors, drop_digest)
+    updated_hash_descriptors = extract_hash_descriptors(updated_descriptors, drop_digest)
+    assert original_hash_descriptors == updated_hash_descriptors, \
+        f"Hash descriptors' parameters should be the same for {image_path}. " \
+        f"Original hash descriptors: {original_hash_descriptors}, " \
+        f"updated hash descriptors: {updated_hash_descriptors}"
+
+def drop_digest(descriptor):
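+    """Return a copy of the descriptor without its 'Digest' entry."""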
+    return {k: v for k, v in descriptor.items() if k != "Digest"}
+
+def AddHashFooter(args, key, image_path, additional_images=()):
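+    """Re-sign the image at image_path with the given key.
+
+    The original salt, props, partition size and rollback index are preserved;
+    descriptors of any additional_images are re-included by reference.
+    """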
     if os.path.basename(image_path) in args.key_overrides:
         key = args.key_overrides[os.path.basename(image_path)]
     info, descriptors = AvbInfo(args, image_path)
-    if info:
-        # Extract hash descriptor of original image.
-        descs = {desc['Partition Name']:desc for desc in LookUp(descriptors, 'Hash descriptor')}
-        if additional_images:
-            for additional_image in additional_images:
-                _, additional_desc = AvbInfo(args, additional_image)
-                del descs[LookUp(additional_desc, 'Hash descriptor')[0]['Partition Name']]
-        assert len(descs) == 1, \
-            f'multiple hash descriptors except additional descriptors exist in {image_path}'
-        original_image_descriptor = list(descs.values())[0]
+    assert info is not None, f'no avbinfo: {image_path}'
 
-        image_size = ReadBytesSize(info['Image size'])
-        algorithm = info['Algorithm']
-        original_image_partition_name = original_image_descriptor['Partition Name']
-        original_image_salt = original_image_descriptor['Salt']
-        partition_size = str(image_size)
+    # Extract hash descriptor of original image.
+    hash_descriptors_original = extract_hash_descriptors(descriptors, drop_digest)
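+    # Hash descriptors contributed by additional images must match those of
+    # image_path; drop them so that only image_path's own descriptor remains.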
+    for additional_image in additional_images:
+        _, additional_desc = AvbInfo(args, additional_image)
+        hash_descriptors = extract_hash_descriptors(additional_desc, drop_digest)
+        for k, v in hash_descriptors.items():
+            assert v == hash_descriptors_original[k], \
+                f"Hash descriptor of {k} in {additional_image} and {image_path} should be " \
+                f"the same. {additional_image}: {v}, {image_path}: {hash_descriptors_original[k]}"
+            del hash_descriptors_original[k]
+    assert len(hash_descriptors_original) == 1, \
+        f"Only one hash descriptor is expected for {image_path} after removing " \
+        f"additional images. Hash descriptors: {hash_descriptors_original}"
+    [(original_image_partition_name, original_image_descriptor)] = hash_descriptors_original.items()
+    assert info["Original image size"] == original_image_descriptor["Image Size"], \
+        f"Original image size should be the same as the image size in the hash descriptor " \
+        f"for {image_path}. Original image size: {info['Original image size']}, " \
+        f"image size in the hash descriptor: {original_image_descriptor['Image Size']}"
 
-        cmd = ['avbtool', 'add_hash_footer',
-               '--key', key,
-               '--algorithm', algorithm,
-               '--partition_name', original_image_partition_name,
-               '--salt', original_image_salt,
-               '--partition_size', partition_size,
-               '--image', image_path]
-        AppendPropArgument(cmd, descriptors)
-        if args.signing_args:
-            cmd.extend(shlex.split(args.signing_args))
-        if additional_images:
-            for additional_image in additional_images:
-                cmd.extend(['--include_descriptors_from_image', additional_image])
+    partition_size = str(ReadBytesSize(info['Image size']))
+    algorithm = info['Algorithm']
+    original_image_salt = original_image_descriptor['Salt']
 
-        if 'Rollback Index' in info:
-            cmd.extend(['--rollback_index', info['Rollback Index']])
-        RunCommand(args, cmd)
-        resigned_info, _ = AvbInfo(args, image_path)
-        check_no_size_change_on_resigned_image(image_path, info, resigned_info)
+    cmd = ['avbtool', 'add_hash_footer',
+           '--key', key,
+           '--algorithm', algorithm,
+           '--partition_name', original_image_partition_name,
+           '--salt', original_image_salt,
+           '--partition_size', partition_size,
+           '--image', image_path]
+    AppendPropArgument(cmd, descriptors)
+    if args.signing_args:
+        cmd.extend(shlex.split(args.signing_args))
+    for additional_image in additional_images:
+        cmd.extend(['--include_descriptors_from_image', additional_image])
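+    # The original image is expected to always carry a rollback index; a missing
+    # 'Rollback Index' entry raises KeyError here.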
+    cmd.extend(['--rollback_index', info['Rollback Index']])
+
+    RunCommand(args, cmd)
+    check_resigned_image_avb_info(image_path, info, descriptors, args)
 
 def AddHashTreeFooter(args, key, image_path):
     if os.path.basename(image_path) in args.key_overrides:
         key = args.key_overrides[os.path.basename(image_path)]
     info, descriptors = AvbInfo(args, image_path)
     if info:
-        descriptor = LookUp(descriptors, 'Hashtree descriptor')[0]
+        descriptor = find_all_values_by_key(descriptors, 'Hashtree descriptor')[0]
         image_size = ReadBytesSize(info['Image size'])
         algorithm = info['Algorithm']
         partition_name = descriptor['Partition Name']
@@ -295,8 +332,7 @@
         if args.signing_args:
             cmd.extend(shlex.split(args.signing_args))
         RunCommand(args, cmd)
-        resigned_info, _ = AvbInfo(args, image_path)
-        check_no_size_change_on_resigned_image(image_path, info, resigned_info)
+        check_resigned_image_avb_info(image_path, info, descriptors, args)
 
 
 def UpdateVbmetaBootconfig(args, initrds, vbmeta_img):
@@ -412,8 +448,7 @@
             cmd.extend(shlex.split(args.signing_args))
 
         RunCommand(args, cmd)
-        resigned_info, _ = AvbInfo(args, vbmeta_img)
-        check_no_size_change_on_resigned_image(vbmeta_img, info, resigned_info)
+        check_resigned_image_avb_info(vbmeta_img, info, descriptors, args)
         # libavb expects to be able to read the maximum vbmeta size, so we must provide a partition
         # which matches this or the read will fail.
         with open(vbmeta_img, 'a', encoding='utf8') as f:
@@ -531,7 +566,8 @@
         initrd_debug_file = files[initrd_debug]
 
         _, kernel_image_descriptors = AvbInfo(args, kernel_file)
-        salts = {desc['Partition Name']:desc['Salt'] for desc in LookUp(kernel_image_descriptors, 'Hash descriptor')}
+        salts = extract_hash_descriptors(
+            kernel_image_descriptors, lambda descriptor: descriptor['Salt'])
         initrd_normal_hashdesc = tempfile.NamedTemporaryFile(delete=False).name
         initrd_debug_hashdesc = tempfile.NamedTemporaryFile(delete=False).name
         initrd_n_f = Async(GenVbmetaImage, args, initrd_normal_file,
@@ -556,46 +592,53 @@
 
     # Re-sign rialto if it exists. Rialto only exists in the arm64 environment.
     if os.path.exists(files['rialto']):
-        update_kernel_hashes_task = Async(
-            update_initrd_hashes_in_rialto, original_kernel_descriptors, args,
+        update_initrd_digests_task = Async(
+            update_initrd_digests_in_rialto, original_kernel_descriptors, args,
             files, wait=[resign_kernel_task])
-        Async(resign_rialto, args, key, files['rialto'], wait=[update_kernel_hashes_task])
+        Async(resign_rialto, args, key, files['rialto'], wait=[update_initrd_digests_task])
 
 def resign_rialto(args, key, rialto_path):
-    original_info, original_descriptors = AvbInfo(args, rialto_path)
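+    """Re-sign rialto with the given key and verify the updated AVB footer."""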
+    _, original_descriptors = AvbInfo(args, rialto_path)
     AddHashFooter(args, key, rialto_path)
 
     # Verify the new AVB footer.
     updated_info, updated_descriptors = AvbInfo(args, rialto_path)
-    assert_same_value(original_info, updated_info, "Original image size", "rialto")
-    original_descriptor = find_hash_descriptor_by_partition_name(original_descriptors, 'boot')
-    updated_descriptor = find_hash_descriptor_by_partition_name(updated_descriptors, 'boot')
+    assert len(updated_descriptors) == 2, \
+        f"There should be two descriptors for rialto. Updated descriptors: {updated_descriptors}"
+    updated_prop = find_all_values_by_key(updated_descriptors, "Prop")
+    assert len(updated_prop) == 1, "There should be only one Prop descriptor for rialto. " \
+        f"Updated descriptors: {updated_descriptors}"
+    assert updated_info["Rollback Index"] != "0", "Rollback index should not be zero for rialto."
+
+    # Verify the only hash descriptor of rialto.
+    updated_hash_descriptors = extract_hash_descriptors(updated_descriptors)
+    assert len(updated_hash_descriptors) == 1, \
+        f"There should be only one hash descriptor for rialto. " \
+        f"Updated hash descriptors: {updated_hash_descriptors}"
     # Since the salt is not updated, a change in the digest reflects a change in
     # the content of the rialto kernel.
-    assert_same_value(original_descriptor, updated_descriptor, "Salt", "rialto_hash_descriptor")
     if not args.do_not_update_bootconfigs:
+        [(_, original_descriptor)] = extract_hash_descriptors(original_descriptors).items()
+        [(_, updated_descriptor)] = updated_hash_descriptors.items()
         assert_different_value(original_descriptor, updated_descriptor, "Digest",
                                "rialto_hash_descriptor")
 
-def assert_same_value(original, updated, key, context):
-    assert original[key] == updated[key], \
-        f"Value of '{key}' should not change for '{context}'" \
-        f"Original value: {original[key]}, updated value: {updated[key]}"
-
 def assert_different_value(original, updated, key, context):
     assert original[key] != updated[key], \
         f"Value of '{key}' should change for '{context}'" \
         f"Original value: {original[key]}, updated value: {updated[key]}"
 
-def update_initrd_hashes_in_rialto(original_descriptors, args, files):
+def update_initrd_digests_in_rialto(original_descriptors, args, files):
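+    """Replace the stale initrd digests embedded in rialto with the digests of
+    the re-signed initrds taken from the updated kernel descriptors.
+    """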
     _, updated_descriptors = AvbInfo(args, files['kernel'])
 
-    original_kernel_descriptor = find_hash_descriptor_by_partition_name(
-        original_descriptors, 'boot')
-    updated_kernel_descriptor = find_hash_descriptor_by_partition_name(
-        updated_descriptors, 'boot')
-    assert_same_value(original_kernel_descriptor, updated_kernel_descriptor,
-                      "Digest", "microdroid_kernel_hash_descriptor")
+    original_digests = extract_hash_descriptors(
+        original_descriptors, lambda x: binascii.unhexlify(x['Digest']))
+    updated_digests = extract_hash_descriptors(
+        updated_descriptors, lambda x: binascii.unhexlify(x['Digest']))
+    assert original_digests.pop("boot") == updated_digests.pop("boot"), \
+        "Hash descriptor of boot should not change for kernel. " \
+        f"Original descriptors: {original_descriptors}, " \
+        f"updated descriptors: {updated_descriptors}"
 
     # Update the digests of initrd_normal and initrd_debug in rialto if the
     # bootconfigs in them were updated.
@@ -605,47 +648,35 @@
     with open(files['rialto'], "rb") as file:
         content = file.read()
 
-    partition_names = ['initrd_normal', 'initrd_debug']
-    all_digests = set()
-    for partition_name in partition_names:
-        original_descriptor = find_hash_descriptor_by_partition_name(
-            original_descriptors, partition_name)
-        update_descriptor = find_hash_descriptor_by_partition_name(
-            updated_descriptors, partition_name)
+    # Check that the original and updated digests are different before updating rialto.
+    partition_names = {'initrd_normal', 'initrd_debug'}
+    assert set(original_digests.keys()) == set(updated_digests.keys()) == partition_names, \
+        f"Original digests' partitions should be {partition_names}. " \
+        f"Original digests: {original_digests}. Updated digests: {updated_digests}"
+    assert set(original_digests.values()).isdisjoint(updated_digests.values()), \
+        "Digests of initrd_normal and initrd_debug should change. " \
+        f"Original descriptors: {original_descriptors}, " \
+        f"updated descriptors: {updated_descriptors}"
 
-        original_digest = binascii.unhexlify(original_descriptor['Digest'])
-        updated_digest = binascii.unhexlify(update_descriptor['Digest'])
+    for partition_name, original_digest in original_digests.items():
+        updated_digest = updated_digests[partition_name]
         assert len(original_digest) == len(updated_digest), \
             f"Length of original_digest and updated_digest must be the same for {partition_name}." \
             f" Original digest: {original_digest}, updated digest: {updated_digest}"
-        assert original_digest != updated_digest, \
-            f"Digest of the partition {partition_name} should change. " \
-            f"Original digest: {original_digest}, updated digest: {updated_digest}"
-        all_digests.add(original_digest)
-        all_digests.add(updated_digest)
 
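+        # Patch the raw digest bytes in place; the old and new digests have the
+        # same length, so the file size and all other offsets are preserved.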
         new_content = content.replace(original_digest, updated_digest)
         assert len(new_content) == len(content), \
             "Length of new_content and content must be the same."
         assert new_content != content, \
-            f"original digest of the partition {partition_name} not found. " \
-            f"Original descriptor: {original_descriptor}"
+            f"original digest of the partition {partition_name} not found."
         content = new_content
 
-    assert len(all_digests) == len(partition_names) * 2, \
-        f"There should be {len(partition_names) * 2} different digests for the original and " \
-        f"updated descriptors for initrd. Original descriptors: {original_descriptors}, " \
-        f"updated descriptors: {updated_descriptors}"
-
     with open(files['rialto'], "wb") as file:
         file.write(content)
 
-def find_hash_descriptor_by_partition_name(descriptors, partition_name):
-    """Find the hash descriptor of the partition in the descriptors."""
-    for descriptor_type, descriptor in descriptors:
-        if descriptor_type == 'Hash descriptor' and descriptor['Partition Name'] == partition_name:
-            return descriptor
-    assert False, f'Failed to find hash descriptor for partition {partition_name}'
+def extract_hash_descriptors(descriptors, f=lambda x: x):
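+    """Map each hash descriptor's partition name to f(descriptor).
+
+    For example, extract_hash_descriptors(descriptors, lambda d: d['Salt'])
+    returns a dict mapping each partition name to its salt.
+    """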
+    return {desc["Partition Name"]: f(desc) for desc in
+            find_all_values_by_key(descriptors, "Hash descriptor")}
 
 def VerifyVirtApex(args):
     key = args.key