Merge "[release] Add more check for rialto resigning" into main
diff --git a/apex/sign_virt_apex.py b/apex/sign_virt_apex.py
index 0bcbac7..0b6137b 100644
--- a/apex/sign_virt_apex.py
+++ b/apex/sign_virt_apex.py
@@ -208,74 +208,111 @@
return info, descriptors
-# Look up a list of (key, value) with a key. Returns the list of value(s) with the matching key.
-# The order of those values is maintained.
-def LookUp(pairs, key):
+def find_all_values_by_key(pairs, key):
+ """Find all the values of the key in the pairs."""
return [v for (k, v) in pairs if k == key]
# Extract properties from the descriptors of original vbmeta image,
# append to command as parameter.
def AppendPropArgument(cmd, descriptors):
- for prop in LookUp(descriptors, 'Prop'):
+ for prop in find_all_values_by_key(descriptors, 'Prop'):
cmd.append('--prop')
result = re.match(r"(.+) -> '(.+)'", prop)
cmd.append(result.group(1) + ":" + result.group(2))
-def check_no_size_change_on_resigned_image(image_path, original_image_info, resigned_image_info):
- assert original_image_info is not None, f'no avbinfo on original image: {image_path}'
- assert resigned_image_info is not None, f'no avbinfo on resigned image: {image_path}'
- assert_same_value(original_image_info, resigned_image_info, "Header Block", image_path)
- assert_same_value(original_image_info, resigned_image_info, "Authentication Block", image_path)
- assert_same_value(original_image_info, resigned_image_info, "Auxiliary Block", image_path)
+def check_resigned_image_avb_info(image_path, original_info, original_descriptors, args):
+ updated_info, updated_descriptors = AvbInfo(args, image_path)
+ assert original_info is not None, f'no avbinfo on original image: {image_path}'
+ assert updated_info is not None, f'no avbinfo on resigned image: {image_path}'
+ assert_different_value(original_info, updated_info, "Public key (sha1)", image_path)
+ updated_public_key = updated_info.pop("Public key (sha1)")
+ if not hasattr(check_resigned_image_avb_info, "new_public_key"):
+ check_resigned_image_avb_info.new_public_key = updated_public_key
+ else:
+ assert check_resigned_image_avb_info.new_public_key == updated_public_key, \
+ "All images should be resigned with the same public key. Expected public key (sha1):" \
+ f" {check_resigned_image_avb_info.new_public_key}, actual public key (sha1): " \
+ f"{updated_public_key}, Path: {image_path}"
+ original_info.pop("Public key (sha1)")
+ assert original_info == updated_info, \
+ f"Original info and updated info should be the same for {image_path}. " \
+ f"Original info: {original_info}, updated info: {updated_info}"
-def AddHashFooter(args, key, image_path, additional_images=None):
+ # Verify the descriptors of the original and updated images.
+ assert len(original_descriptors) == len(updated_descriptors), \
+ f"Number of descriptors should be the same for {image_path}. " \
+ f"Original descriptors: {original_descriptors}, updated descriptors: {updated_descriptors}"
+ original_prop_descriptors = sorted(find_all_values_by_key(original_descriptors, "Prop"))
+ updated_prop_descriptors = sorted(find_all_values_by_key(updated_descriptors, "Prop"))
+ assert original_prop_descriptors == updated_prop_descriptors, \
+ f"Prop descriptors should be the same for {image_path}. " \
+ f"Original prop descriptors: {original_prop_descriptors}, " \
+ f"updated prop descriptors: {updated_prop_descriptors}"
+
+ # Remove digest from hash descriptors before comparing, since some digests should change.
+ original_hash_descriptors = extract_hash_descriptors(original_descriptors, drop_digest)
+ updated_hash_descriptors = extract_hash_descriptors(updated_descriptors, drop_digest)
+ assert original_hash_descriptors == updated_hash_descriptors, \
+ f"Hash descriptors' parameters should be the same for {image_path}. " \
+ f"Original hash descriptors: {original_hash_descriptors}, " \
+ f"updated hash descriptors: {updated_hash_descriptors}"
+
+def drop_digest(descriptor):
+ return {k: v for k, v in descriptor.items() if k != "Digest"}
+
+def AddHashFooter(args, key, image_path, additional_images=()):
if os.path.basename(image_path) in args.key_overrides:
key = args.key_overrides[os.path.basename(image_path)]
info, descriptors = AvbInfo(args, image_path)
- if info:
- # Extract hash descriptor of original image.
- descs = {desc['Partition Name']:desc for desc in LookUp(descriptors, 'Hash descriptor')}
- if additional_images:
- for additional_image in additional_images:
- _, additional_desc = AvbInfo(args, additional_image)
- del descs[LookUp(additional_desc, 'Hash descriptor')[0]['Partition Name']]
- assert len(descs) == 1, \
- f'multiple hash descriptors except additional descriptors exist in {image_path}'
- original_image_descriptor = list(descs.values())[0]
+ assert info is not None, f'no avbinfo: {image_path}'
- image_size = ReadBytesSize(info['Image size'])
- algorithm = info['Algorithm']
- original_image_partition_name = original_image_descriptor['Partition Name']
- original_image_salt = original_image_descriptor['Salt']
- partition_size = str(image_size)
+ # Extract hash descriptor of original image.
+ hash_descriptors_original = extract_hash_descriptors(descriptors, drop_digest)
+ for additional_image in additional_images:
+ _, additional_desc = AvbInfo(args, additional_image)
+ hash_descriptors = extract_hash_descriptors(additional_desc, drop_digest)
+ for k, v in hash_descriptors.items():
+ assert v == hash_descriptors_original[k], \
+ f"Hash descriptor of {k} in {additional_image} and {image_path} should be " \
+ f"the same. {additional_image}: {v}, {image_path}: {hash_descriptors_original[k]}"
+ del hash_descriptors_original[k]
+ assert len(hash_descriptors_original) == 1, \
+ f"Only one hash descriptor is expected for {image_path} after removing " \
+ f"additional images. Hash descriptors: {hash_descriptors_original}"
+ [(original_image_partition_name, original_image_descriptor)] = hash_descriptors_original.items()
+ assert info["Original image size"] == original_image_descriptor["Image Size"], \
+ f"Original image size should be the same as the image size in the hash descriptor " \
+ f"for {image_path}. Original image size: {info['Original image size']}, " \
+ f"image size in the hash descriptor: {original_image_descriptor['Image Size']}"
- cmd = ['avbtool', 'add_hash_footer',
- '--key', key,
- '--algorithm', algorithm,
- '--partition_name', original_image_partition_name,
- '--salt', original_image_salt,
- '--partition_size', partition_size,
- '--image', image_path]
- AppendPropArgument(cmd, descriptors)
- if args.signing_args:
- cmd.extend(shlex.split(args.signing_args))
- if additional_images:
- for additional_image in additional_images:
- cmd.extend(['--include_descriptors_from_image', additional_image])
+ partition_size = str(ReadBytesSize(info['Image size']))
+ algorithm = info['Algorithm']
+ original_image_salt = original_image_descriptor['Salt']
- if 'Rollback Index' in info:
- cmd.extend(['--rollback_index', info['Rollback Index']])
- RunCommand(args, cmd)
- resigned_info, _ = AvbInfo(args, image_path)
- check_no_size_change_on_resigned_image(image_path, info, resigned_info)
+ cmd = ['avbtool', 'add_hash_footer',
+ '--key', key,
+ '--algorithm', algorithm,
+ '--partition_name', original_image_partition_name,
+ '--salt', original_image_salt,
+ '--partition_size', partition_size,
+ '--image', image_path]
+ AppendPropArgument(cmd, descriptors)
+ if args.signing_args:
+ cmd.extend(shlex.split(args.signing_args))
+ for additional_image in additional_images:
+ cmd.extend(['--include_descriptors_from_image', additional_image])
+ cmd.extend(['--rollback_index', info['Rollback Index']])
+
+ RunCommand(args, cmd)
+ check_resigned_image_avb_info(image_path, info, descriptors, args)
def AddHashTreeFooter(args, key, image_path):
if os.path.basename(image_path) in args.key_overrides:
key = args.key_overrides[os.path.basename(image_path)]
info, descriptors = AvbInfo(args, image_path)
if info:
- descriptor = LookUp(descriptors, 'Hashtree descriptor')[0]
+ descriptor = find_all_values_by_key(descriptors, 'Hashtree descriptor')[0]
image_size = ReadBytesSize(info['Image size'])
algorithm = info['Algorithm']
partition_name = descriptor['Partition Name']
@@ -295,8 +332,7 @@
if args.signing_args:
cmd.extend(shlex.split(args.signing_args))
RunCommand(args, cmd)
- resigned_info, _ = AvbInfo(args, image_path)
- check_no_size_change_on_resigned_image(image_path, info, resigned_info)
+ check_resigned_image_avb_info(image_path, info, descriptors, args)
def UpdateVbmetaBootconfig(args, initrds, vbmeta_img):
@@ -412,8 +448,7 @@
cmd.extend(shlex.split(args.signing_args))
RunCommand(args, cmd)
- resigned_info, _ = AvbInfo(args, vbmeta_img)
- check_no_size_change_on_resigned_image(vbmeta_img, info, resigned_info)
+ check_resigned_image_avb_info(vbmeta_img, info, descriptors, args)
# libavb expects to be able to read the maximum vbmeta size, so we must provide a partition
# which matches this or the read will fail.
with open(vbmeta_img, 'a', encoding='utf8') as f:
@@ -531,7 +566,8 @@
initrd_debug_file = files[initrd_debug]
_, kernel_image_descriptors = AvbInfo(args, kernel_file)
- salts = {desc['Partition Name']:desc['Salt'] for desc in LookUp(kernel_image_descriptors, 'Hash descriptor')}
+ salts = extract_hash_descriptors(
+ kernel_image_descriptors, lambda descriptor: descriptor['Salt'])
initrd_normal_hashdesc = tempfile.NamedTemporaryFile(delete=False).name
initrd_debug_hashdesc = tempfile.NamedTemporaryFile(delete=False).name
initrd_n_f = Async(GenVbmetaImage, args, initrd_normal_file,
@@ -556,46 +592,53 @@
# Re-sign rialto if it exists. Rialto only exists in arm64 environment.
if os.path.exists(files['rialto']):
- update_kernel_hashes_task = Async(
- update_initrd_hashes_in_rialto, original_kernel_descriptors, args,
+ update_initrd_digests_task = Async(
+ update_initrd_digests_in_rialto, original_kernel_descriptors, args,
files, wait=[resign_kernel_task])
- Async(resign_rialto, args, key, files['rialto'], wait=[update_kernel_hashes_task])
+ Async(resign_rialto, args, key, files['rialto'], wait=[update_initrd_digests_task])
def resign_rialto(args, key, rialto_path):
- original_info, original_descriptors = AvbInfo(args, rialto_path)
+ _, original_descriptors = AvbInfo(args, rialto_path)
AddHashFooter(args, key, rialto_path)
# Verify the new AVB footer.
updated_info, updated_descriptors = AvbInfo(args, rialto_path)
- assert_same_value(original_info, updated_info, "Original image size", "rialto")
- original_descriptor = find_hash_descriptor_by_partition_name(original_descriptors, 'boot')
- updated_descriptor = find_hash_descriptor_by_partition_name(updated_descriptors, 'boot')
+ assert len(updated_descriptors) == 2, \
+ f"There should be two descriptors for rialto. Updated descriptors: {updated_descriptors}"
+ updated_prop = find_all_values_by_key(updated_descriptors, "Prop")
+ assert len(updated_prop) == 1, "There should be only one Prop descriptor for rialto. " \
+ f"Updated descriptors: {updated_descriptors}"
+ assert updated_info["Rollback Index"] != "0", "Rollback index should not be zero for rialto."
+
+ # Verify the only hash descriptor of rialto.
+ updated_hash_descriptors = extract_hash_descriptors(updated_descriptors)
+ assert len(updated_hash_descriptors) == 1, \
+ f"There should be only one hash descriptor for rialto. " \
+ f"Updated hash descriptors: {updated_hash_descriptors}"
# Since salt is not updated, the change of digest reflects the change of content of rialto
# kernel.
- assert_same_value(original_descriptor, updated_descriptor, "Salt", "rialto_hash_descriptor")
if not args.do_not_update_bootconfigs:
+ [(_, original_descriptor)] = extract_hash_descriptors(original_descriptors).items()
+ [(_, updated_descriptor)] = updated_hash_descriptors.items()
assert_different_value(original_descriptor, updated_descriptor, "Digest",
"rialto_hash_descriptor")
-def assert_same_value(original, updated, key, context):
- assert original[key] == updated[key], \
- f"Value of '{key}' should not change for '{context}'" \
- f"Original value: {original[key]}, updated value: {updated[key]}"
-
def assert_different_value(original, updated, key, context):
assert original[key] != updated[key], \
f"Value of '{key}' should change for '{context}'" \
f"Original value: {original[key]}, updated value: {updated[key]}"
-def update_initrd_hashes_in_rialto(original_descriptors, args, files):
+def update_initrd_digests_in_rialto(original_descriptors, args, files):
_, updated_descriptors = AvbInfo(args, files['kernel'])
- original_kernel_descriptor = find_hash_descriptor_by_partition_name(
- original_descriptors, 'boot')
- updated_kernel_descriptor = find_hash_descriptor_by_partition_name(
- updated_descriptors, 'boot')
- assert_same_value(original_kernel_descriptor, updated_kernel_descriptor,
- "Digest", "microdroid_kernel_hash_descriptor")
+ original_digests = extract_hash_descriptors(
+ original_descriptors, lambda x: binascii.unhexlify(x['Digest']))
+ updated_digests = extract_hash_descriptors(
+ updated_descriptors, lambda x: binascii.unhexlify(x['Digest']))
+ assert original_digests.pop("boot") == updated_digests.pop("boot"), \
+ "Hash descriptor of boot should not change for kernel. " \
+ f"Original descriptors: {original_descriptors}, " \
+ f"updated descriptors: {updated_descriptors}"
# Update the hashes of initrd_normal and initrd_debug in rialto if the
# bootconfigs in them are updated.
@@ -605,47 +648,35 @@
with open(files['rialto'], "rb") as file:
content = file.read()
- partition_names = ['initrd_normal', 'initrd_debug']
- all_digests = set()
- for partition_name in partition_names:
- original_descriptor = find_hash_descriptor_by_partition_name(
- original_descriptors, partition_name)
- update_descriptor = find_hash_descriptor_by_partition_name(
- updated_descriptors, partition_name)
+ # Check that the original and updated digests are different before updating rialto.
+ partition_names = {'initrd_normal', 'initrd_debug'}
+ assert set(original_digests.keys()) == set(updated_digests.keys()) == partition_names, \
+ f"Original digests' partitions should be {partition_names}. " \
+ f"Original digests: {original_digests}. Updated digests: {updated_digests}"
+ assert set(original_digests.values()).isdisjoint(updated_digests.values()), \
+ "Digests of initrd_normal and initrd_debug should change. " \
+ f"Original descriptors: {original_descriptors}, " \
+ f"updated descriptors: {updated_descriptors}"
- original_digest = binascii.unhexlify(original_descriptor['Digest'])
- updated_digest = binascii.unhexlify(update_descriptor['Digest'])
+ for partition_name, original_digest in original_digests.items():
+ updated_digest = updated_digests[partition_name]
assert len(original_digest) == len(updated_digest), \
f"Length of original_digest and updated_digest must be the same for {partition_name}." \
f" Original digest: {original_digest}, updated digest: {updated_digest}"
- assert original_digest != updated_digest, \
- f"Digest of the partition {partition_name} should change. " \
- f"Original digest: {original_digest}, updated digest: {updated_digest}"
- all_digests.add(original_digest)
- all_digests.add(updated_digest)
new_content = content.replace(original_digest, updated_digest)
assert len(new_content) == len(content), \
"Length of new_content and content must be the same."
assert new_content != content, \
- f"original digest of the partition {partition_name} not found. " \
- f"Original descriptor: {original_descriptor}"
+ f"original digest of the partition {partition_name} not found."
content = new_content
- assert len(all_digests) == len(partition_names) * 2, \
- f"There should be {len(partition_names) * 2} different digests for the original and " \
- f"updated descriptors for initrd. Original descriptors: {original_descriptors}, " \
- f"updated descriptors: {updated_descriptors}"
-
with open(files['rialto'], "wb") as file:
file.write(content)
-def find_hash_descriptor_by_partition_name(descriptors, partition_name):
- """Find the hash descriptor of the partition in the descriptors."""
- for descriptor_type, descriptor in descriptors:
- if descriptor_type == 'Hash descriptor' and descriptor['Partition Name'] == partition_name:
- return descriptor
- assert False, f'Failed to find hash descriptor for partition {partition_name}'
+def extract_hash_descriptors(descriptors, f=lambda x: x):
+ return {desc["Partition Name"]: f(desc) for desc in
+ find_all_values_by_key(descriptors, "Hash descriptor")}
def VerifyVirtApex(args):
key = args.key