Merge "Use MissingAction for building constraint_spec" into main
diff --git a/apex/sign_virt_apex.py b/apex/sign_virt_apex.py
index 5758214..0bcbac7 100644
--- a/apex/sign_virt_apex.py
+++ b/apex/sign_virt_apex.py
@@ -27,6 +27,7 @@
 - lpmake, lpunpack, simg2img, img2simg, initrd_bootconfig
 """
 import argparse
+import binascii
 import builtins
 import hashlib
 import os
@@ -224,9 +225,9 @@
 def check_no_size_change_on_resigned_image(image_path, original_image_info, resigned_image_info):
     assert original_image_info is not None, f'no avbinfo on original image: {image_path}'
     assert resigned_image_info is not None, f'no avbinfo on resigned image: {image_path}'
-    assert original_image_info['Header Block'] == resigned_image_info['Header Block'], f'header block size mismatch: {image_path}'
-    assert original_image_info['Authentication Block'] == resigned_image_info['Authentication Block'], f'authentication block size mismatch: {image_path}'
-    assert original_image_info['Auxiliary Block'] == resigned_image_info['Auxiliary Block'], f'auxiliary block size mismatch: {image_path}'
+    for block_name in ("Header Block", "Authentication Block", "Auxiliary Block"):
+        assert_same_value(original_image_info, resigned_image_info, block_name,
+                          image_path)
 
 def AddHashFooter(args, key, image_path, additional_images=None):
     if os.path.basename(image_path) in args.key_overrides:
@@ -539,12 +540,12 @@
         initrd_d_f = Async(GenVbmetaImage, args, initrd_debug_file,
                            initrd_debug_hashdesc, "initrd_debug", salts["initrd_debug"],
                            wait=[vbmeta_bc_f] if vbmeta_bc_f is not None else [])
-        Async(AddHashFooter, args, key, kernel_file,
-              additional_images=[
-                  initrd_normal_hashdesc, initrd_debug_hashdesc],
+        return Async(AddHashFooter, args, key, kernel_file,
+              additional_images=[initrd_normal_hashdesc, initrd_debug_hashdesc],
               wait=[initrd_n_f, initrd_d_f])
 
-    resign_kernel('kernel', 'initrd_normal.img', 'initrd_debuggable.img')
+    _, original_kernel_descriptors = AvbInfo(args, files['kernel'])
+    resign_kernel_task = resign_kernel('kernel', 'initrd_normal.img', 'initrd_debuggable.img')
 
     for ver in gki_versions:
         if f'gki-{ver}_kernel' in files:
@@ -555,8 +556,96 @@
 
     # Re-sign rialto if it exists. Rialto only exists in arm64 environment.
     if os.path.exists(files['rialto']):
-        Async(AddHashFooter, args, key, files['rialto'])
+        update_kernel_hashes_task = Async(
+            update_initrd_hashes_in_rialto, original_kernel_descriptors, args,
+            files, wait=[resign_kernel_task])
+        Async(resign_rialto, args, key, files['rialto'], wait=[update_kernel_hashes_task])
 
+def resign_rialto(args, key, rialto_path):
+    """Re-sign rialto in place with AddHashFooter, then sanity-check the new footer."""
+    info_before, descriptors_before = AvbInfo(args, rialto_path)
+    AddHashFooter(args, key, rialto_path)
+    info_after, descriptors_after = AvbInfo(args, rialto_path)
+    # Re-signing must not alter the original image size recorded in the footer.
+    assert_same_value(info_before, info_after, "Original image size", "rialto")
+    boot_before = find_hash_descriptor_by_partition_name(descriptors_before, 'boot')
+    boot_after = find_hash_descriptor_by_partition_name(descriptors_after, 'boot')
+    # The salt is reused, so any change in the 'boot' digest reflects a change in the
+    # content of the rialto kernel itself.
+    descriptor_context = "rialto_hash_descriptor"
+    assert_same_value(boot_before, boot_after, "Salt", descriptor_context)
+    if not args.do_not_update_bootconfigs:
+        assert_different_value(boot_before, boot_after, "Digest", descriptor_context)
+
+def assert_same_value(original, updated, key, context):
+    """Assert original[key] == updated[key]; `context` names the object being checked."""
+    assert original[key] == updated[key], f"Value of '{key}' should not change for " \
+        f"'{context}'. Original value: {original[key]}, updated value: {updated[key]}"
+
+def assert_different_value(original, updated, key, context):
+    """Assert original[key] != updated[key]; `context` names the object being checked."""
+    assert original[key] != updated[key], f"Value of '{key}' should change for " \
+        f"'{context}'. Original value: {original[key]}, updated value: {updated[key]}"
+
+def update_initrd_hashes_in_rialto(original_descriptors, args, files):
+    """Replace the pre-re-signing initrd digests embedded in rialto with the new ones."""
+    _, updated_descriptors = AvbInfo(args, files['kernel'])
+
+    original_kernel_descriptor = find_hash_descriptor_by_partition_name(
+        original_descriptors, 'boot')
+    updated_kernel_descriptor = find_hash_descriptor_by_partition_name(
+        updated_descriptors, 'boot')
+    assert_same_value(original_kernel_descriptor, updated_kernel_descriptor,
+                      "Digest", "microdroid_kernel_hash_descriptor")
+
+    # Update the hashes of initrd_normal and initrd_debug in rialto only if the
+    # bootconfigs in them (and therefore their digests) were updated.
+    if args.do_not_update_bootconfigs:
+        return
+
+    with open(files['rialto'], "rb") as file:
+        content = file.read()
+
+    partition_names = ['initrd_normal', 'initrd_debug']
+    all_digests = set()
+    for partition_name in partition_names:
+        original_descriptor = find_hash_descriptor_by_partition_name(
+            original_descriptors, partition_name)
+        updated_descriptor = find_hash_descriptor_by_partition_name(
+            updated_descriptors, partition_name)
+
+        original_digest = binascii.unhexlify(original_descriptor['Digest'])
+        updated_digest = binascii.unhexlify(updated_descriptor['Digest'])
+        assert len(original_digest) == len(updated_digest), \
+            f"Length of original_digest and updated_digest must be the same for {partition_name}." \
+            f" Original digest: {original_digest.hex()}, updated digest: {updated_digest.hex()}"
+        assert original_digest != updated_digest, \
+            f"Digest of the partition {partition_name} should change. " \
+            f"Original digest: {original_digest.hex()}, updated digest: {updated_digest.hex()}"
+        all_digests.add(original_digest)
+        all_digests.add(updated_digest)
+
+        # Digests were asserted equal-length above, so replace() cannot change the file size.
+        new_content = content.replace(original_digest, updated_digest)
+        assert new_content != content, \
+            f"original digest of the partition {partition_name} not found. " \
+            f"Original descriptor: {original_descriptor}"
+        content = new_content
+
+    assert len(all_digests) == len(partition_names) * 2, \
+        f"There should be {len(partition_names) * 2} different digests for the original and " \
+        f"updated descriptors for initrd. Original descriptors: {original_descriptors}, " \
+        f"updated descriptors: {updated_descriptors}"
+
+    with open(files['rialto'], "wb") as file:
+        file.write(content)
+
+def find_hash_descriptor_by_partition_name(descriptors, partition_name):
+    """Return the 'Hash descriptor' for partition_name; raise AssertionError if none."""
+    for descriptor_type, descriptor in descriptors:
+        if descriptor_type == 'Hash descriptor' and descriptor['Partition Name'] == partition_name:
+            return descriptor
+    raise AssertionError(f'Failed to find hash descriptor for partition {partition_name}')
 
 def VerifyVirtApex(args):
     key = args.key