[release] Add more check for rialto resigning

This cl introduces additional checks for the AVB footer generated
during the rialto resigning. The checks serve to ensure the
correctness of the rialto resigning.

The cl also simplifies the code with some refactoring.

Bug: 319235308
Test: atest --host sign_virt_apex_test
Test: atest MicrodroidHostTests
Change-Id: I83f22e603b21a50f3099b231936ef221c5159b99
diff --git a/apex/sign_virt_apex.py b/apex/sign_virt_apex.py
index 0bcbac7..0b6137b 100644
--- a/apex/sign_virt_apex.py
+++ b/apex/sign_virt_apex.py
@@ -208,74 +208,111 @@
     return info, descriptors
 
 
-# Look up a list of (key, value) with a key. Returns the list of value(s) with the matching key.
-# The order of those values is maintained.
-def LookUp(pairs, key):
+def find_all_values_by_key(pairs, key):
+    """Find all the values of the key in the pairs."""
     return [v for (k, v) in pairs if k == key]
 
 # Extract properties from the descriptors of original vbmeta image,
 # append to command as parameter.
 def AppendPropArgument(cmd, descriptors):
-    for prop in LookUp(descriptors, 'Prop'):
+    for prop in find_all_values_by_key(descriptors, 'Prop'):
         cmd.append('--prop')
         result = re.match(r"(.+) -> '(.+)'", prop)
         cmd.append(result.group(1) + ":" + result.group(2))
 
 
-def check_no_size_change_on_resigned_image(image_path, original_image_info, resigned_image_info):
-    assert original_image_info is not None, f'no avbinfo on original image: {image_path}'
-    assert resigned_image_info is not None, f'no avbinfo on resigned image: {image_path}'
-    assert_same_value(original_image_info, resigned_image_info, "Header Block", image_path)
-    assert_same_value(original_image_info, resigned_image_info, "Authentication Block", image_path)
-    assert_same_value(original_image_info, resigned_image_info, "Auxiliary Block", image_path)
+def check_resigned_image_avb_info(image_path, original_info, original_descriptors, args):
+    updated_info, updated_descriptors = AvbInfo(args, image_path)
+    assert original_info is not None, f'no avbinfo on original image: {image_path}'
+    assert updated_info is not None, f'no avbinfo on resigned image: {image_path}'
+    assert_different_value(original_info, updated_info, "Public key (sha1)", image_path)
+    updated_public_key = updated_info.pop("Public key (sha1)")
+    if not hasattr(check_resigned_image_avb_info, "new_public_key"):
+        check_resigned_image_avb_info.new_public_key = updated_public_key
+    else:
+        assert check_resigned_image_avb_info.new_public_key == updated_public_key, \
+            "All images should be resigned with the same public key. Expected public key (sha1):" \
+            f" {check_resigned_image_avb_info.new_public_key}, actual public key (sha1): " \
+            f"{updated_public_key}, Path: {image_path}"
+    original_info.pop("Public key (sha1)")
+    assert original_info == updated_info, \
+        f"Original info and updated info should be the same for {image_path}. " \
+        f"Original info: {original_info}, updated info: {updated_info}"
 
-def AddHashFooter(args, key, image_path, additional_images=None):
+    # Verify the descriptors of the original and updated images.
+    assert len(original_descriptors) == len(updated_descriptors), \
+        f"Number of descriptors should be the same for {image_path}. " \
+        f"Original descriptors: {original_descriptors}, updated descriptors: {updated_descriptors}"
+    original_prop_descriptors = sorted(find_all_values_by_key(original_descriptors, "Prop"))
+    updated_prop_descriptors = sorted(find_all_values_by_key(updated_descriptors, "Prop"))
+    assert original_prop_descriptors == updated_prop_descriptors, \
+        f"Prop descriptors should be the same for {image_path}. " \
+        f"Original prop descriptors: {original_prop_descriptors}, " \
+        f"updated prop descriptors: {updated_prop_descriptors}"
+
+    # Remove digest from hash descriptors before comparing, since some digests should change.
+    original_hash_descriptors = extract_hash_descriptors(original_descriptors, drop_digest)
+    updated_hash_descriptors = extract_hash_descriptors(updated_descriptors, drop_digest)
+    assert original_hash_descriptors == updated_hash_descriptors, \
+        f"Hash descriptors' parameters should be the same for {image_path}. " \
+        f"Original hash descriptors: {original_hash_descriptors}, " \
+        f"updated hash descriptors: {updated_hash_descriptors}"
+
+def drop_digest(descriptor):
+    return {k: v for k, v in descriptor.items() if k != "Digest"}
+
+def AddHashFooter(args, key, image_path, additional_images=()):
     if os.path.basename(image_path) in args.key_overrides:
         key = args.key_overrides[os.path.basename(image_path)]
     info, descriptors = AvbInfo(args, image_path)
-    if info:
-        # Extract hash descriptor of original image.
-        descs = {desc['Partition Name']:desc for desc in LookUp(descriptors, 'Hash descriptor')}
-        if additional_images:
-            for additional_image in additional_images:
-                _, additional_desc = AvbInfo(args, additional_image)
-                del descs[LookUp(additional_desc, 'Hash descriptor')[0]['Partition Name']]
-        assert len(descs) == 1, \
-            f'multiple hash descriptors except additional descriptors exist in {image_path}'
-        original_image_descriptor = list(descs.values())[0]
+    assert info is not None, f'no avbinfo: {image_path}'
 
-        image_size = ReadBytesSize(info['Image size'])
-        algorithm = info['Algorithm']
-        original_image_partition_name = original_image_descriptor['Partition Name']
-        original_image_salt = original_image_descriptor['Salt']
-        partition_size = str(image_size)
+    # Extract hash descriptor of original image.
+    hash_descriptors_original = extract_hash_descriptors(descriptors, drop_digest)
+    for additional_image in additional_images:
+        _, additional_desc = AvbInfo(args, additional_image)
+        hash_descriptors = extract_hash_descriptors(additional_desc, drop_digest)
+        for k, v in hash_descriptors.items():
+            assert v == hash_descriptors_original[k], \
+                f"Hash descriptor of {k} in {additional_image} and {image_path} should be " \
+                f"the same. {additional_image}: {v}, {image_path}: {hash_descriptors_original[k]}"
+            del hash_descriptors_original[k]
+    assert len(hash_descriptors_original) == 1, \
+        f"Only one hash descriptor is expected for {image_path} after removing " \
+        f"additional images. Hash descriptors: {hash_descriptors_original}"
+    [(original_image_partition_name, original_image_descriptor)] = hash_descriptors_original.items()
+    assert info["Original image size"] == original_image_descriptor["Image Size"], \
+        f"Original image size should be the same as the image size in the hash descriptor " \
+        f"for {image_path}. Original image size: {info['Original image size']}, " \
+        f"image size in the hash descriptor: {original_image_descriptor['Image Size']}"
 
-        cmd = ['avbtool', 'add_hash_footer',
-               '--key', key,
-               '--algorithm', algorithm,
-               '--partition_name', original_image_partition_name,
-               '--salt', original_image_salt,
-               '--partition_size', partition_size,
-               '--image', image_path]
-        AppendPropArgument(cmd, descriptors)
-        if args.signing_args:
-            cmd.extend(shlex.split(args.signing_args))
-        if additional_images:
-            for additional_image in additional_images:
-                cmd.extend(['--include_descriptors_from_image', additional_image])
+    partition_size = str(ReadBytesSize(info['Image size']))
+    algorithm = info['Algorithm']
+    original_image_salt = original_image_descriptor['Salt']
 
-        if 'Rollback Index' in info:
-            cmd.extend(['--rollback_index', info['Rollback Index']])
-        RunCommand(args, cmd)
-        resigned_info, _ = AvbInfo(args, image_path)
-        check_no_size_change_on_resigned_image(image_path, info, resigned_info)
+    cmd = ['avbtool', 'add_hash_footer',
+           '--key', key,
+           '--algorithm', algorithm,
+           '--partition_name', original_image_partition_name,
+           '--salt', original_image_salt,
+           '--partition_size', partition_size,
+           '--image', image_path]
+    AppendPropArgument(cmd, descriptors)
+    if args.signing_args:
+        cmd.extend(shlex.split(args.signing_args))
+    for additional_image in additional_images:
+        cmd.extend(['--include_descriptors_from_image', additional_image])
+    cmd.extend(['--rollback_index', info['Rollback Index']])
+
+    RunCommand(args, cmd)
+    check_resigned_image_avb_info(image_path, info, descriptors, args)
 
 def AddHashTreeFooter(args, key, image_path):
     if os.path.basename(image_path) in args.key_overrides:
         key = args.key_overrides[os.path.basename(image_path)]
     info, descriptors = AvbInfo(args, image_path)
     if info:
-        descriptor = LookUp(descriptors, 'Hashtree descriptor')[0]
+        descriptor = find_all_values_by_key(descriptors, 'Hashtree descriptor')[0]
         image_size = ReadBytesSize(info['Image size'])
         algorithm = info['Algorithm']
         partition_name = descriptor['Partition Name']
@@ -295,8 +332,7 @@
         if args.signing_args:
             cmd.extend(shlex.split(args.signing_args))
         RunCommand(args, cmd)
-        resigned_info, _ = AvbInfo(args, image_path)
-        check_no_size_change_on_resigned_image(image_path, info, resigned_info)
+        check_resigned_image_avb_info(image_path, info, descriptors, args)
 
 
 def UpdateVbmetaBootconfig(args, initrds, vbmeta_img):
@@ -412,8 +448,7 @@
             cmd.extend(shlex.split(args.signing_args))
 
         RunCommand(args, cmd)
-        resigned_info, _ = AvbInfo(args, vbmeta_img)
-        check_no_size_change_on_resigned_image(vbmeta_img, info, resigned_info)
+        check_resigned_image_avb_info(vbmeta_img, info, descriptors, args)
         # libavb expects to be able to read the maximum vbmeta size, so we must provide a partition
         # which matches this or the read will fail.
         with open(vbmeta_img, 'a', encoding='utf8') as f:
@@ -531,7 +566,8 @@
         initrd_debug_file = files[initrd_debug]
 
         _, kernel_image_descriptors = AvbInfo(args, kernel_file)
-        salts = {desc['Partition Name']:desc['Salt'] for desc in LookUp(kernel_image_descriptors, 'Hash descriptor')}
+        salts = extract_hash_descriptors(
+            kernel_image_descriptors, lambda descriptor: descriptor['Salt'])
         initrd_normal_hashdesc = tempfile.NamedTemporaryFile(delete=False).name
         initrd_debug_hashdesc = tempfile.NamedTemporaryFile(delete=False).name
         initrd_n_f = Async(GenVbmetaImage, args, initrd_normal_file,
@@ -556,46 +592,53 @@
 
     # Re-sign rialto if it exists. Rialto only exists in arm64 environment.
     if os.path.exists(files['rialto']):
-        update_kernel_hashes_task = Async(
-            update_initrd_hashes_in_rialto, original_kernel_descriptors, args,
+        update_initrd_digests_task = Async(
+            update_initrd_digests_in_rialto, original_kernel_descriptors, args,
             files, wait=[resign_kernel_task])
-        Async(resign_rialto, args, key, files['rialto'], wait=[update_kernel_hashes_task])
+        Async(resign_rialto, args, key, files['rialto'], wait=[update_initrd_digests_task])
 
 def resign_rialto(args, key, rialto_path):
-    original_info, original_descriptors = AvbInfo(args, rialto_path)
+    _, original_descriptors = AvbInfo(args, rialto_path)
     AddHashFooter(args, key, rialto_path)
 
     # Verify the new AVB footer.
     updated_info, updated_descriptors = AvbInfo(args, rialto_path)
-    assert_same_value(original_info, updated_info, "Original image size", "rialto")
-    original_descriptor = find_hash_descriptor_by_partition_name(original_descriptors, 'boot')
-    updated_descriptor = find_hash_descriptor_by_partition_name(updated_descriptors, 'boot')
+    assert len(updated_descriptors) == 2, \
+        f"There should be two descriptors for rialto. Updated descriptors: {updated_descriptors}"
+    updated_prop = find_all_values_by_key(updated_descriptors, "Prop")
+    assert len(updated_prop) == 1, "There should be only one Prop descriptor for rialto. " \
+        f"Updated descriptors: {updated_descriptors}"
+    assert updated_info["Rollback Index"] != "0", "Rollback index should not be zero for rialto."
+
+    # Verify the only hash descriptor of rialto.
+    updated_hash_descriptors = extract_hash_descriptors(updated_descriptors)
+    assert len(updated_hash_descriptors) == 1, \
+        f"There should be only one hash descriptor for rialto. " \
+        f"Updated hash descriptors: {updated_hash_descriptors}"
     # Since salt is not updated, the change of digest reflects the change of content of rialto
     # kernel.
-    assert_same_value(original_descriptor, updated_descriptor, "Salt", "rialto_hash_descriptor")
     if not args.do_not_update_bootconfigs:
+        [(_, original_descriptor)] = extract_hash_descriptors(original_descriptors).items()
+        [(_, updated_descriptor)] = updated_hash_descriptors.items()
         assert_different_value(original_descriptor, updated_descriptor, "Digest",
                                "rialto_hash_descriptor")
 
-def assert_same_value(original, updated, key, context):
-    assert original[key] == updated[key], \
-        f"Value of '{key}' should not change for '{context}'" \
-        f"Original value: {original[key]}, updated value: {updated[key]}"
-
 def assert_different_value(original, updated, key, context):
     assert original[key] != updated[key], \
         f"Value of '{key}' should change for '{context}'" \
         f"Original value: {original[key]}, updated value: {updated[key]}"
 
-def update_initrd_hashes_in_rialto(original_descriptors, args, files):
+def update_initrd_digests_in_rialto(original_descriptors, args, files):
     _, updated_descriptors = AvbInfo(args, files['kernel'])
 
-    original_kernel_descriptor = find_hash_descriptor_by_partition_name(
-        original_descriptors, 'boot')
-    updated_kernel_descriptor = find_hash_descriptor_by_partition_name(
-        updated_descriptors, 'boot')
-    assert_same_value(original_kernel_descriptor, updated_kernel_descriptor,
-                      "Digest", "microdroid_kernel_hash_descriptor")
+    original_digests = extract_hash_descriptors(
+        original_descriptors, lambda x: binascii.unhexlify(x['Digest']))
+    updated_digests = extract_hash_descriptors(
+        updated_descriptors, lambda x: binascii.unhexlify(x['Digest']))
+    assert original_digests.pop("boot") == updated_digests.pop("boot"), \
+        "Hash descriptor of boot should not change for kernel. " \
+        f"Original descriptors: {original_descriptors}, " \
+        f"updated descriptors: {updated_descriptors}"
 
     # Update the hashes of initrd_normal and initrd_debug in rialto if the
     # bootconfigs in them are updated.
@@ -605,47 +648,35 @@
     with open(files['rialto'], "rb") as file:
         content = file.read()
 
-    partition_names = ['initrd_normal', 'initrd_debug']
-    all_digests = set()
-    for partition_name in partition_names:
-        original_descriptor = find_hash_descriptor_by_partition_name(
-            original_descriptors, partition_name)
-        update_descriptor = find_hash_descriptor_by_partition_name(
-            updated_descriptors, partition_name)
+    # Check that the original and updated digests are different before updating rialto.
+    partition_names = {'initrd_normal', 'initrd_debug'}
+    assert set(original_digests.keys()) == set(updated_digests.keys()) == partition_names, \
+        f"Original digests' partitions should be {partition_names}. " \
+        f"Original digests: {original_digests}. Updated digests: {updated_digests}"
+    assert set(original_digests.values()).isdisjoint(updated_digests.values()), \
+        "Digests of initrd_normal and initrd_debug should change. " \
+        f"Original descriptors: {original_descriptors}, " \
+        f"updated descriptors: {updated_descriptors}"
 
-        original_digest = binascii.unhexlify(original_descriptor['Digest'])
-        updated_digest = binascii.unhexlify(update_descriptor['Digest'])
+    for partition_name, original_digest in original_digests.items():
+        updated_digest = updated_digests[partition_name]
         assert len(original_digest) == len(updated_digest), \
             f"Length of original_digest and updated_digest must be the same for {partition_name}." \
             f" Original digest: {original_digest}, updated digest: {updated_digest}"
-        assert original_digest != updated_digest, \
-            f"Digest of the partition {partition_name} should change. " \
-            f"Original digest: {original_digest}, updated digest: {updated_digest}"
-        all_digests.add(original_digest)
-        all_digests.add(updated_digest)
 
         new_content = content.replace(original_digest, updated_digest)
         assert len(new_content) == len(content), \
             "Length of new_content and content must be the same."
         assert new_content != content, \
-            f"original digest of the partition {partition_name} not found. " \
-            f"Original descriptor: {original_descriptor}"
+            f"original digest of the partition {partition_name} not found."
         content = new_content
 
-    assert len(all_digests) == len(partition_names) * 2, \
-        f"There should be {len(partition_names) * 2} different digests for the original and " \
-        f"updated descriptors for initrd. Original descriptors: {original_descriptors}, " \
-        f"updated descriptors: {updated_descriptors}"
-
     with open(files['rialto'], "wb") as file:
         file.write(content)
 
-def find_hash_descriptor_by_partition_name(descriptors, partition_name):
-    """Find the hash descriptor of the partition in the descriptors."""
-    for descriptor_type, descriptor in descriptors:
-        if descriptor_type == 'Hash descriptor' and descriptor['Partition Name'] == partition_name:
-            return descriptor
-    assert False, f'Failed to find hash descriptor for partition {partition_name}'
+def extract_hash_descriptors(descriptors, f=lambda x: x):
+    return {desc["Partition Name"]: f(desc) for desc in
+            find_all_values_by_key(descriptors, "Hash descriptor")}
 
 def VerifyVirtApex(args):
     key = args.key