Propagate salt when resigning

Bug: 319235308
Test: atest --host sign_virt_apex_test
Test: sign_virt_apex <private key> <virt_apex dir>
Test: avbtool info_image --image <artifacts before resigning>
Test: avbtool info_image --image <artifacts after resigning>

Change-Id: I0922ce2f8421a0f51f63e403d008a2f70125e232
diff --git a/apex/sign_virt_apex.py b/apex/sign_virt_apex.py
index 74bccba..5758214 100644
--- a/apex/sign_virt_apex.py
+++ b/apex/sign_virt_apex.py
@@ -228,27 +228,40 @@
     assert original_image_info['Authentication Block'] == resigned_image_info['Authentication Block'], f'authentication block size mismatch: {image_path}'
     assert original_image_info['Auxiliary Block'] == resigned_image_info['Auxiliary Block'], f'auxiliary block size mismatch: {image_path}'
 
-def AddHashFooter(args, key, image_path, partition_name, additional_descriptors=None):
+def AddHashFooter(args, key, image_path, additional_images=None):
     if os.path.basename(image_path) in args.key_overrides:
         key = args.key_overrides[os.path.basename(image_path)]
     info, descriptors = AvbInfo(args, image_path)
     if info:
+        # Extract hash descriptor of original image.
+        descs = {desc['Partition Name']:desc for desc in LookUp(descriptors, 'Hash descriptor')}
+        if additional_images:
+            for additional_image in additional_images:
+                _, additional_desc = AvbInfo(args, additional_image)
+                del descs[LookUp(additional_desc, 'Hash descriptor')[0]['Partition Name']]
+        assert len(descs) == 1, \
+            f'multiple hash descriptors except additional descriptors exist in {image_path}'
+        original_image_descriptor = list(descs.values())[0]
+
         image_size = ReadBytesSize(info['Image size'])
         algorithm = info['Algorithm']
+        original_image_partition_name = original_image_descriptor['Partition Name']
+        original_image_salt = original_image_descriptor['Salt']
         partition_size = str(image_size)
 
         cmd = ['avbtool', 'add_hash_footer',
                '--key', key,
                '--algorithm', algorithm,
-               '--partition_name', partition_name,
+               '--partition_name', original_image_partition_name,
+               '--salt', original_image_salt,
                '--partition_size', partition_size,
                '--image', image_path]
         AppendPropArgument(cmd, descriptors)
         if args.signing_args:
             cmd.extend(shlex.split(args.signing_args))
-        if additional_descriptors:
-            for image in additional_descriptors:
-                cmd.extend(['--include_descriptors_from_image', image])
+        if additional_images:
+            for additional_image in additional_images:
+                cmd.extend(['--include_descriptors_from_image', additional_image])
 
         if 'Rollback Index' in info:
             cmd.extend(['--rollback_index', info['Rollback Index']])
@@ -266,6 +279,7 @@
         algorithm = info['Algorithm']
         partition_name = descriptor['Partition Name']
         hash_algorithm = descriptor['Hash Algorithm']
+        salt = descriptor['Salt']
         partition_size = str(image_size)
         cmd = ['avbtool', 'add_hashtree_footer',
                '--key', key,
@@ -274,6 +288,7 @@
                '--partition_size', partition_size,
                '--do_not_generate_fec',
                '--hash_algorithm', hash_algorithm,
+               '--salt', salt,
                '--image', image_path]
         AppendPropArgument(cmd, descriptors)
         if args.signing_args:
@@ -426,10 +441,11 @@
         RunCommand(args, cmd)
 
 
-def GenVbmetaImage(args, image, output, partition_name):
+def GenVbmetaImage(args, image, output, partition_name, salt):
     cmd = ['avbtool', 'add_hash_footer', '--dynamic_partition_size',
            '--do_not_append_vbmeta_image',
            '--partition_name', partition_name,
+           '--salt', salt,
            '--image', image,
            '--output_vbmeta_image', output]
     RunCommand(args, cmd)
@@ -513,16 +529,18 @@
         initrd_normal_file = files[initrd_normal]
         initrd_debug_file = files[initrd_debug]
 
+        _, kernel_image_descriptors = AvbInfo(args, kernel_file)
+        salts = {desc['Partition Name']:desc['Salt'] for desc in LookUp(kernel_image_descriptors, 'Hash descriptor')}
         initrd_normal_hashdesc = tempfile.NamedTemporaryFile(delete=False).name
         initrd_debug_hashdesc = tempfile.NamedTemporaryFile(delete=False).name
         initrd_n_f = Async(GenVbmetaImage, args, initrd_normal_file,
-                           initrd_normal_hashdesc, "initrd_normal",
+                           initrd_normal_hashdesc, "initrd_normal", salts["initrd_normal"],
                            wait=[vbmeta_bc_f] if vbmeta_bc_f is not None else [])
         initrd_d_f = Async(GenVbmetaImage, args, initrd_debug_file,
-                           initrd_debug_hashdesc, "initrd_debug",
+                           initrd_debug_hashdesc, "initrd_debug", salts["initrd_debug"],
                            wait=[vbmeta_bc_f] if vbmeta_bc_f is not None else [])
-        Async(AddHashFooter, args, key, kernel_file, partition_name="boot",
-              additional_descriptors=[
+        Async(AddHashFooter, args, key, kernel_file,
+              additional_images=[
                   initrd_normal_hashdesc, initrd_debug_hashdesc],
               wait=[initrd_n_f, initrd_d_f])
 
@@ -537,7 +555,7 @@
 
     # Re-sign rialto if it exists. Rialto only exists in arm64 environment.
     if os.path.exists(files['rialto']):
-        Async(AddHashFooter, args, key, files['rialto'], partition_name='boot')
+        Async(AddHashFooter, args, key, files['rialto'])
 
 
 def VerifyVirtApex(args):