[attestation] Update initrd hashes in rialto.bin

This cl updates the initrd hashes embedded in the rialto.bin
after resigning as the resigning process changes the content of
initrd. Thus the initrd hashed embedded in the rialto.bin also
needs to be updated after the resigning.

Bug: 319235308
Test: atest --host sign_virt_apex_test
Change-Id: I5e9752dd6575e367409b04968107edef67b544f4
diff --git a/apex/sign_virt_apex.py b/apex/sign_virt_apex.py
index 5758214..0bcbac7 100644
--- a/apex/sign_virt_apex.py
+++ b/apex/sign_virt_apex.py
@@ -27,6 +27,7 @@
 - lpmake, lpunpack, simg2img, img2simg, initrd_bootconfig
 """
 import argparse
+import binascii
 import builtins
 import hashlib
 import os
@@ -224,9 +225,9 @@
 def check_no_size_change_on_resigned_image(image_path, original_image_info, resigned_image_info):
     assert original_image_info is not None, f'no avbinfo on original image: {image_path}'
     assert resigned_image_info is not None, f'no avbinfo on resigned image: {image_path}'
-    assert original_image_info['Header Block'] == resigned_image_info['Header Block'], f'header block size mismatch: {image_path}'
-    assert original_image_info['Authentication Block'] == resigned_image_info['Authentication Block'], f'authentication block size mismatch: {image_path}'
-    assert original_image_info['Auxiliary Block'] == resigned_image_info['Auxiliary Block'], f'auxiliary block size mismatch: {image_path}'
+    assert_same_value(original_image_info, resigned_image_info, "Header Block", image_path)
+    assert_same_value(original_image_info, resigned_image_info, "Authentication Block", image_path)
+    assert_same_value(original_image_info, resigned_image_info, "Auxiliary Block", image_path)
 
 def AddHashFooter(args, key, image_path, additional_images=None):
     if os.path.basename(image_path) in args.key_overrides:
@@ -539,12 +540,12 @@
         initrd_d_f = Async(GenVbmetaImage, args, initrd_debug_file,
                            initrd_debug_hashdesc, "initrd_debug", salts["initrd_debug"],
                            wait=[vbmeta_bc_f] if vbmeta_bc_f is not None else [])
-        Async(AddHashFooter, args, key, kernel_file,
-              additional_images=[
-                  initrd_normal_hashdesc, initrd_debug_hashdesc],
+        return Async(AddHashFooter, args, key, kernel_file,
+              additional_images=[initrd_normal_hashdesc, initrd_debug_hashdesc],
               wait=[initrd_n_f, initrd_d_f])
 
-    resign_kernel('kernel', 'initrd_normal.img', 'initrd_debuggable.img')
+    _, original_kernel_descriptors = AvbInfo(args, files['kernel'])
+    resign_kernel_task = resign_kernel('kernel', 'initrd_normal.img', 'initrd_debuggable.img')
 
     for ver in gki_versions:
         if f'gki-{ver}_kernel' in files:
@@ -555,8 +556,96 @@
 
     # Re-sign rialto if it exists. Rialto only exists in arm64 environment.
     if os.path.exists(files['rialto']):
-        Async(AddHashFooter, args, key, files['rialto'])
+        update_kernel_hashes_task = Async(
+            update_initrd_hashes_in_rialto, original_kernel_descriptors, args,
+            files, wait=[resign_kernel_task])
+        Async(resign_rialto, args, key, files['rialto'], wait=[update_kernel_hashes_task])
 
+def resign_rialto(args, key, rialto_path):
+    original_info, original_descriptors = AvbInfo(args, rialto_path)
+    AddHashFooter(args, key, rialto_path)
+
+    # Verify the new AVB footer.
+    updated_info, updated_descriptors = AvbInfo(args, rialto_path)
+    assert_same_value(original_info, updated_info, "Original image size", "rialto")
+    original_descriptor = find_hash_descriptor_by_partition_name(original_descriptors, 'boot')
+    updated_descriptor = find_hash_descriptor_by_partition_name(updated_descriptors, 'boot')
+    # Since salt is not updated, the change of digest reflects the change of content of rialto
+    # kernel.
+    assert_same_value(original_descriptor, updated_descriptor, "Salt", "rialto_hash_descriptor")
+    if not args.do_not_update_bootconfigs:
+        assert_different_value(original_descriptor, updated_descriptor, "Digest",
+                               "rialto_hash_descriptor")
+
+def assert_same_value(original, updated, key, context):
+    assert original[key] == updated[key], \
+        f"Value of '{key}' should not change for '{context}'" \
+        f"Original value: {original[key]}, updated value: {updated[key]}"
+
+def assert_different_value(original, updated, key, context):
+    assert original[key] != updated[key], \
+        f"Value of '{key}' should change for '{context}'" \
+        f"Original value: {original[key]}, updated value: {updated[key]}"
+
+def update_initrd_hashes_in_rialto(original_descriptors, args, files):
+    _, updated_descriptors = AvbInfo(args, files['kernel'])
+
+    original_kernel_descriptor = find_hash_descriptor_by_partition_name(
+        original_descriptors, 'boot')
+    updated_kernel_descriptor = find_hash_descriptor_by_partition_name(
+        updated_descriptors, 'boot')
+    assert_same_value(original_kernel_descriptor, updated_kernel_descriptor,
+                      "Digest", "microdroid_kernel_hash_descriptor")
+
+    # Update the hashes of initrd_normal and initrd_debug in rialto if the
+    # bootconfigs in them are updated.
+    if args.do_not_update_bootconfigs:
+        return
+
+    with open(files['rialto'], "rb") as file:
+        content = file.read()
+
+    partition_names = ['initrd_normal', 'initrd_debug']
+    all_digests = set()
+    for partition_name in partition_names:
+        original_descriptor = find_hash_descriptor_by_partition_name(
+            original_descriptors, partition_name)
+        update_descriptor = find_hash_descriptor_by_partition_name(
+            updated_descriptors, partition_name)
+
+        original_digest = binascii.unhexlify(original_descriptor['Digest'])
+        updated_digest = binascii.unhexlify(update_descriptor['Digest'])
+        assert len(original_digest) == len(updated_digest), \
+            f"Length of original_digest and updated_digest must be the same for {partition_name}." \
+            f" Original digest: {original_digest}, updated digest: {updated_digest}"
+        assert original_digest != updated_digest, \
+            f"Digest of the partition {partition_name} should change. " \
+            f"Original digest: {original_digest}, updated digest: {updated_digest}"
+        all_digests.add(original_digest)
+        all_digests.add(updated_digest)
+
+        new_content = content.replace(original_digest, updated_digest)
+        assert len(new_content) == len(content), \
+            "Length of new_content and content must be the same."
+        assert new_content != content, \
+            f"original digest of the partition {partition_name} not found. " \
+            f"Original descriptor: {original_descriptor}"
+        content = new_content
+
+    assert len(all_digests) == len(partition_names) * 2, \
+        f"There should be {len(partition_names) * 2} different digests for the original and " \
+        f"updated descriptors for initrd. Original descriptors: {original_descriptors}, " \
+        f"updated descriptors: {updated_descriptors}"
+
+    with open(files['rialto'], "wb") as file:
+        file.write(content)
+
+def find_hash_descriptor_by_partition_name(descriptors, partition_name):
+    """Find the hash descriptor of the partition in the descriptors."""
+    for descriptor_type, descriptor in descriptors:
+        if descriptor_type == 'Hash descriptor' and descriptor['Partition Name'] == partition_name:
+            return descriptor
+    assert False, f'Failed to find hash descriptor for partition {partition_name}'
 
 def VerifyVirtApex(args):
     key = args.key