[release] Replace initrds in rialto for GKI kernels during resigning

Bug: 326363997
Test: atest --host sign_virt_apex_test
Change-Id: I49fac606edbbe0845efd93576d7179c7a08317e2
diff --git a/apex/sign_virt_apex.py b/apex/sign_virt_apex.py
index fbbd152..7c59b54 100644
--- a/apex/sign_virt_apex.py
+++ b/apex/sign_virt_apex.py
@@ -609,20 +609,25 @@
             return resign_decompressed_kernel(kernel_file, initrd_normal_file, initrd_debug_file)
 
     _, original_kernel_descriptors = AvbInfo(args, files['kernel'])
-    resign_kernel_task = resign_kernel('kernel', 'initrd_normal.img', 'initrd_debuggable.img')
+    resign_kernel_tasks = [resign_kernel('kernel', 'initrd_normal.img', 'initrd_debuggable.img')]
+    original_kernels = {"kernel" : original_kernel_descriptors}
 
     for ver in gki_versions:
         if f'gki-{ver}_kernel' in files:
-            resign_kernel(
-                f'gki-{ver}_kernel',
+            kernel_name = f'gki-{ver}_kernel'
+            _, original_kernel_descriptors = AvbInfo(args, files[kernel_name])
+            task = resign_kernel(
+                kernel_name,
                 f'gki-{ver}_initrd_normal.img',
                 f'gki-{ver}_initrd_debuggable.img')
+            resign_kernel_tasks.append(task)
+            original_kernels[kernel_name] = original_kernel_descriptors
 
     # Re-sign rialto if it exists. Rialto only exists in arm64 environment.
     if os.path.exists(files['rialto']):
         update_initrd_digests_task = Async(
-            update_initrd_digests_in_rialto, original_kernel_descriptors, args,
-            files, wait=[resign_kernel_task])
+            update_initrd_digests_of_kernels_in_rialto, original_kernels, args, files,
+            wait=resign_kernel_tasks)
         Async(resign_rialto, args, key, files['rialto'], wait=[update_initrd_digests_task])
 
 def resign_rialto(args, key, rialto_path):
@@ -656,18 +661,7 @@
         f"Value of '{key}' should change for '{context}'" \
         f"Original value: {original[key]}, updated value: {updated[key]}"
 
-def update_initrd_digests_in_rialto(original_descriptors, args, files):
-    _, updated_descriptors = AvbInfo(args, files['kernel'])
-
-    original_digests = extract_hash_descriptors(
-        original_descriptors, lambda x: binascii.unhexlify(x['Digest']))
-    updated_digests = extract_hash_descriptors(
-        updated_descriptors, lambda x: binascii.unhexlify(x['Digest']))
-    assert original_digests.pop("boot") == updated_digests.pop("boot"), \
-        "Hash descriptor of boot should not change for kernel. " \
-        f"Original descriptors: {original_descriptors}, " \
-        f"updated descriptors: {updated_descriptors}"
-
+def update_initrd_digests_of_kernels_in_rialto(original_kernels, args, files):
     # Update the hashes of initrd_normal and initrd_debug in rialto if the
     # bootconfigs in them are updated.
     if args.do_not_update_bootconfigs:
@@ -676,6 +670,26 @@
     with open(files['rialto'], "rb") as file:
         content = file.read()
 
+    for kernel_name, descriptors in original_kernels.items():
+        content = update_initrd_digests_in_rialto(
+            descriptors, args, files, kernel_name, content)
+
+    with open(files['rialto'], "wb") as file:
+        file.write(content)
+
+def update_initrd_digests_in_rialto(
+        original_descriptors, args, files, kernel_name, content):
+    _, updated_descriptors = AvbInfo(args, files[kernel_name])
+
+    original_digests = extract_hash_descriptors(
+        original_descriptors, lambda x: binascii.unhexlify(x['Digest']))
+    updated_digests = extract_hash_descriptors(
+        updated_descriptors, lambda x: binascii.unhexlify(x['Digest']))
+    assert original_digests.pop("boot") == updated_digests.pop("boot"), \
+        "Hash descriptor of boot should not change for " + kernel_name + \
+        f"\nOriginal descriptors: {original_descriptors}, " \
+        f"\nUpdated descriptors: {updated_descriptors}"
+
     # Check that the original and updated digests are different before updating rialto.
     partition_names = {'initrd_normal', 'initrd_debug'}
     assert set(original_digests.keys()) == set(updated_digests.keys()) == partition_names, \
@@ -699,8 +713,7 @@
             f"original digest of the partition {partition_name} not found."
         content = new_content
 
-    with open(files['rialto'], "wb") as file:
-        file.write(content)
+    return content
 
 def extract_hash_descriptors(descriptors, f=lambda x: x):
     return {desc["Partition Name"]: f(desc) for desc in