idsig: less copies when creating the merkle tree
This change improves the merkle tree generation routine so that we don't
do unnecessary data copies. Previously, hashes for level N is written
to a temporary store and then copied into the tree. Even worse, the
hashes written to the tree is copied into another buffer when they are
used as the inputs for the next level.
With this CL, the hashes are directly written to and read from the tree.
This is done by having two (non-overlapping) slices on the hash tree.
Bug: N/A
Test: cargo test
Change-Id: I34be81ece6941eba78980c8bc4697ed5d523ed53
diff --git a/idsig/src/hashtree.rs b/idsig/src/hashtree.rs
index 79ba9d7..a4727a9 100644
--- a/idsig/src/hashtree.rs
+++ b/idsig/src/hashtree.rs
@@ -14,8 +14,8 @@
* limitations under the License.
*/
-use ring::digest::{self, Algorithm};
-use std::io::{Cursor, Read, Result, Seek, SeekFrom, Write};
+use ring::digest::{self, Algorithm, Digest};
+use std::io::{Cursor, Read, Result, Write};
/// `HashTree` is a merkle tree (and its root hash) that is compatible with fs-verity.
pub struct HashTree {
@@ -39,14 +39,14 @@
let tree = generate_hash_tree(input, input_size, &salt, block_size, algorithm)?;
// Root hash is from the first block of the hash or the input data if there is no hash tree
- // generate which can happen when input data is smaller than block size
+ // generated which can happen when input data is smaller than block size
let root_hash = if tree.is_empty() {
- hash_one_level(input, input_size, &salt, block_size, algorithm)?
+ let mut data = Vec::new();
+ input.read_to_end(&mut data)?;
+ hash_one_block(&data, &salt, block_size, algorithm).as_ref().to_vec()
} else {
- let mut ctx = digest::Context::new(algorithm);
- ctx.update(&salt);
- ctx.update(&tree[0..block_size]);
- ctx.finish().as_ref().to_vec()
+ let first_block = &tree[0..block_size];
+ hash_one_block(first_block, &salt, block_size, algorithm).as_ref().to_vec()
};
Ok(HashTree { tree, root_hash })
}
@@ -70,68 +70,65 @@
algorithm: &'static Algorithm,
) -> Result<Vec<u8>> {
let digest_size = algorithm.output_len;
- let (hash_level_offsets, tree_size) =
- calc_hash_level_offsets(input_size, block_size, digest_size);
+ let levels = calc_hash_levels(input_size, block_size, digest_size);
+ let tree_size = levels.iter().map(|r| r.len()).sum();
- let mut hash_tree = Cursor::new(vec![0; tree_size]);
- let mut input_size = input_size;
- for (level, offset) in hash_level_offsets.iter().enumerate() {
- let hashes = if level == 0 {
- hash_one_level(input, input_size, salt, block_size, algorithm)?
+ // The contiguous memory that holds the entire merkle tree
+ let mut hash_tree = vec![0; tree_size];
+
+ for (n, cur) in levels.iter().enumerate() {
+ if n == 0 {
+ // Level 0: the (zero-padded) input stream is hashed into level 0
+ let pad_size = round_to_multiple(input_size, block_size) - input_size;
+ let mut input = input.chain(Cursor::new(vec![0; pad_size]));
+ let mut level0 = Cursor::new(&mut hash_tree[cur.start..cur.end]);
+
+ let mut a_block = vec![0; block_size];
+ let mut num_blocks = (input_size + block_size - 1) / block_size;
+ while num_blocks > 0 {
+ input.read_exact(&mut a_block)?;
+ let h = hash_one_block(&a_block, salt, block_size, algorithm);
+ level0.write_all(h.as_ref()).unwrap();
+ num_blocks -= 1;
+ }
} else {
- // For the intermediate levels, input is the output from the previous level
- hash_tree.seek(SeekFrom::Start(hash_level_offsets[level - 1] as u64)).unwrap();
- hash_one_level(&mut hash_tree, input_size, salt, block_size, algorithm)?
- };
- hash_tree.seek(SeekFrom::Start(*offset as u64)).unwrap();
- hash_tree.write_all(hashes.as_ref()).unwrap();
- // Output from this level becomes input for the next level
- input_size = hashes.len();
+ // Intermediate levels: level n - 1 is hashed into level n
+ // Both levels belong to the same `hash_tree`. In order to have a mutable slice for
+ // level n while having a slice for level n - 1, take the mutable slice for both levels
+ // and split it.
+ let prev = &levels[n - 1];
+ let cur_and_prev = &mut hash_tree[cur.start..prev.end];
+ let (cur, prev) = cur_and_prev.split_at_mut(prev.start);
+ let mut cur = Cursor::new(cur);
+ prev.chunks(block_size).for_each(|data| {
+ let h = hash_one_block(data, salt, block_size, algorithm);
+ cur.write_all(h.as_ref()).unwrap();
+ });
+ }
}
- Ok(hash_tree.into_inner())
+ Ok(hash_tree)
}
-/// Calculate hashes for the blocks in `input`. The end of the last block is zero-padded if needed.
-/// Each block is then hashed, producing a stream of hashes for a level.
-fn hash_one_level<R: Read>(
- input: &mut R,
- input_size: usize,
+/// Hash one block of input using the given hash algorithm and the salt. Input might be smaller
+/// than a block, in which case zero is padded.
+fn hash_one_block(
+ input: &[u8],
salt: &[u8],
block_size: usize,
algorithm: &'static Algorithm,
-) -> Result<Vec<u8>> {
- // Input is zero padded when it's not multiple of blocks. Note that `take()` is also needed to
- // not read more than `input_size` from the `input` reader. This is required because `input`
- // can be from the in-memory hashtree. We need to read only the part of hashtree that is for
- // the current level.
- let pad_size = round_to_multiple(input_size, block_size) - input_size;
- let mut input = input.take(input_size as u64).chain(Cursor::new(vec![0; pad_size]));
-
- // Read one block from input, write the hash of it to the output. Repeat that for all input
- // blocks.
- let mut hashes = Cursor::new(Vec::new());
- let mut buf = vec![0; block_size];
- let mut num_blocks = (input_size + block_size - 1) / block_size;
- while num_blocks > 0 {
- input.read_exact(&mut buf)?;
- let mut ctx = digest::Context::new(algorithm);
- ctx.update(salt);
- ctx.update(&buf);
- let hash = ctx.finish();
- hashes.write_all(hash.as_ref())?;
- num_blocks -= 1;
- }
- Ok(hashes.into_inner())
+) -> Digest {
+ let mut ctx = digest::Context::new(algorithm);
+ ctx.update(salt);
+ ctx.update(input);
+ let pad_size = block_size - input.len();
+ ctx.update(&vec![0; pad_size]);
+ ctx.finish()
}
-/// Calculate the size of hashes for each level, and also returns the total size of the hash tree.
-/// This function is needed because hash tree is stored upside down; hashes for level N is stored
-/// "after" hashes for level N + 1.
-fn calc_hash_level_offsets(
- input_size: usize,
- block_size: usize,
- digest_size: usize,
-) -> (Vec<usize>, usize) {
+type Range = std::ops::Range<usize>;
+
+/// Calculate the ranges of hash for each level
+fn calc_hash_levels(input_size: usize, block_size: usize, digest_size: usize) -> Vec<Range> {
// The input is split into multiple blocks and each block is hashed, which becomes the input
// for the next level. Size of a single hash is `digest_size`.
let mut level_sizes = Vec::new();
@@ -145,9 +142,6 @@
let hashes_size = round_to_multiple(num_blocks * digest_size, block_size);
level_sizes.push(hashes_size);
}
- if level_sizes.is_empty() {
- return ([].to_vec(), 0);
- }
// The hash tree is stored upside down. The top level is at offset 0. The second level comes
// next, and so on. Level 0 is located at the end.
@@ -158,18 +152,18 @@
// Level 1 is at offset 1 (because Level 2 is of size 1)
// Level 0 is at offset 4 (because Level 1 is of size 3)
//
- // This is done by accumulating the sizes in reverse order (i.e. from the highest level to the
- // level 1 (not level 0)
- let mut offsets = level_sizes.iter().rev().take(level_sizes.len() - 1).fold(
- vec![0; 1], // offset for the top level
- |mut offsets, size| {
- offsets.push(offsets.last().unwrap() + size);
- offsets
- },
- );
- offsets.reverse(); // reverse the offsets again so that index N is for level N
- let tree_size = level_sizes.iter().sum();
- (offsets, tree_size)
+ // This is done by scanning the sizes in reverse order
+ let mut ranges = level_sizes
+ .iter()
+ .rev()
+ .scan(0, |prev_end, size| {
+ let range = *prev_end..*prev_end + size;
+ *prev_end = range.end;
+ Some(range)
+ })
+ .collect::<Vec<_>>();
+ ranges.reverse(); // reverse again so that index N is for level N
+ ranges
}
/// Round `n` up to the nearest multiple of `unit`