Add streaming support to LZ4 compress routine
In preparation for enabling lz4diff streaming mode.
Instead of appending data to a std::vector, lz4diff APIs
now accept a sink function to which computed data is
sent. This way, callers can choose to cache data in memory
or stream writes to disk as data arrives.
Bug: 206729162
Test: th
Change-Id: Ib1aea5c1b730d30a1b4814f8d5dd8ce3a8b43826
diff --git a/lz4diff/lz4diff_compress.cc b/lz4diff/lz4diff_compress.cc
index 930954b..ce9082c 100644
--- a/lz4diff/lz4diff_compress.cc
+++ b/lz4diff/lz4diff_compress.cc
@@ -27,6 +27,98 @@
namespace chromeos_update_engine {
+// Compresses |blob| as described by |block_info| and streams the output to
+// |sink| instead of building it in memory. |block_info| must be sorted and
+// contiguous from offset 0 (CHECKed) and must not cover more than |blob|.
+// Uncompressed blocks and bytes past the last block are forwarded verbatim;
+// each compressed block is emitted as exactly |compressed_length| bytes, zero
+// padded at the front if |zero_padding_enabled|, at the back otherwise.
+// Returns true on success; TEST_* failures return early (e.g. a |sink| that
+// reports consuming a different number of bytes than it was given).
+bool TryCompressBlob(std::string_view blob,
+                     const std::vector<CompressedBlock>& block_info,
+                     const bool zero_padding_enabled,
+                     const CompressionAlgorithm compression_algo,
+                     const SinkFunc& sink) {
+  size_t uncompressed_size = 0;
+  for (const auto& block : block_info) {
+    CHECK_EQ(uncompressed_size, block.uncompressed_offset)
+        << "Compressed block info is expected to be sorted.";
+    uncompressed_size += block.uncompressed_length;
+  }
+  // Bounds the compressor over-read and the trailing-data size below.
+  TEST_LE(uncompressed_size, blob.size());
+  auto hc = LZ4_createStreamHC();
+  if (hc == nullptr) {
+    LOG(ERROR) << "Failed to allocate LZ4HC stream state.";
+    return false;
+  }
+  DEFER { LZ4_freeStreamHC(hc); };
+  // Scratch buffer for one compressed block, reused across iterations.
+  Blob block_buffer;
+  for (const auto& block : block_info) {
+    if (!block.IsCompressed()) {
+      const auto uncompressed_block =
+          blob.substr(block.uncompressed_offset, block.uncompressed_length);
+      TEST_EQ(sink(reinterpret_cast<const uint8_t*>(uncompressed_block.data()),
+                   uncompressed_block.size()),
+              uncompressed_block.size());
+      continue;
+    }
+    block_buffer.assign(block.compressed_length, 0);
+    // LZ4 spec enforces that last op of a compressed block must be an insert op
+    // of at least 5 bytes. Compressors will try to conform to that requirement
+    // if the input size is just right. We don't want that. So always give a
+    // little bit more data: the source runs from this block's start to the end
+    // of the covered region, not just to the end of this block.
+    const char* const src = blob.data() + block.uncompressed_offset;
+    char* const dst = reinterpret_cast<char*>(block_buffer.data());
+    int src_size =
+        static_cast<int>(uncompressed_size - block.uncompressed_offset);
+    int ret = 0;
+    switch (compression_algo.type()) {
+      case CompressionAlgorithm::LZ4HC:
+        ret = LZ4_compress_HC_destSize(hc, src, dst, &src_size,
+                                       block.compressed_length,
+                                       compression_algo.level());
+        break;
+      case CompressionAlgorithm::LZ4:
+        ret = LZ4_compress_destSize(
+            src, dst, &src_size, block.compressed_length);
+        break;
+      default:
+        LOG(ERROR) << "Unrecognized compression algorithm: "
+                   << compression_algo.type();
+        return false;
+    }
+    TEST_GT(ret, 0);
+    const uint64_t bytes_written = ret;
+    // Last block may have trailing zeros
+    TEST_LE(bytes_written, block.compressed_length);
+    if (bytes_written < block.compressed_length) {
+      if (zero_padding_enabled) {
+        // Move the compressed bytes to the end and zero the freed prefix.
+        const auto padding = block.compressed_length - bytes_written;
+        std::memmove(
+            block_buffer.data() + padding, block_buffer.data(), bytes_written);
+        std::fill(block_buffer.data(), block_buffer.data() + padding, 0);
+      } else {
+        std::fill(block_buffer.data() + bytes_written,
+                  block_buffer.data() + block.compressed_length,
+                  0);
+      }
+    }
+    TEST_EQ(sink(block_buffer.data(), block_buffer.size()),
+            block_buffer.size());
+  }
+  // Any trailing data past the last block is forwarded as-is.
+  TEST_EQ(
+      sink(reinterpret_cast<const uint8_t*>(blob.data()) + uncompressed_size,
+           blob.size() - uncompressed_size),
+      blob.size() - uncompressed_size);
+  return true;
+}
+
Blob TryCompressBlob(std::string_view blob,
const std::vector<CompressedBlock>& block_info,
const bool zero_padding_enabled,
@@ -39,79 +131,20 @@
uncompressed_size += block.uncompressed_length;
compressed_size += block.compressed_length;
}
- CHECK_EQ(uncompressed_size, blob.size());
- Blob output(utils::RoundUp(compressed_size, kBlockSize));
- auto hc = LZ4_createStreamHC();
- DEFER {
- if (hc) {
- LZ4_freeStreamHC(hc);
- hc = nullptr;
- }
- };
- size_t compressed_offset = 0;
- for (const auto& block : block_info) {
- // Execute the increment at end of each loop
- DEFER { compressed_offset += block.compressed_length; };
- CHECK_LE(compressed_offset + block.compressed_length, output.size());
-
- if (!block.IsCompressed()) {
- std::memcpy(output.data() + compressed_offset,
- blob.data() + block.uncompressed_offset,
- block.compressed_length);
- continue;
- }
- // LZ4 spec enforces that last op of a compressed block must be an insert op
- // of at least 5 bytes. Compressors will try to conform to that requirement
- // if the input size is just right. We don't want that. So always give a
- // little bit more data.
- int src_size = uncompressed_size - block.uncompressed_offset;
- uint64_t bytes_written = 0;
- switch (compression_algo.type()) {
- case CompressionAlgorithm::LZ4HC:
- bytes_written = LZ4_compress_HC_destSize(
- hc,
- blob.data() + block.uncompressed_offset,
- reinterpret_cast<char*>(output.data()) + compressed_offset,
- &src_size,
- block.compressed_length,
- compression_algo.level());
- break;
- case CompressionAlgorithm::LZ4:
- bytes_written = LZ4_compress_destSize(
- blob.data() + block.uncompressed_offset,
- reinterpret_cast<char*>(output.data()) + compressed_offset,
- &src_size,
- block.compressed_length);
- break;
- default:
- CHECK(false) << "Unrecognized compression algorithm: "
- << compression_algo.type();
- break;
- }
- // Last block may have trailing zeros
- CHECK_LE(bytes_written, block.compressed_length);
- if (bytes_written < block.compressed_length) {
- if (zero_padding_enabled) {
- const auto padding = block.compressed_length - bytes_written;
- // LOG(INFO) << "Padding: " << padding;
- CHECK_LE(compressed_offset + padding + bytes_written, output.size());
- std::memmove(output.data() + compressed_offset + padding,
- output.data() + compressed_offset,
- bytes_written);
- CHECK_LE(compressed_offset + padding, output.size());
- std::fill(output.data() + compressed_offset,
- output.data() + compressed_offset + padding,
- 0);
-
- } else {
- std::fill(output.data() + compressed_offset + bytes_written,
- output.data() + compressed_offset + block.compressed_length,
- 0);
- }
- }
+ TEST_EQ(uncompressed_size, blob.size());
+ Blob output;
+ output.reserve(utils::RoundUp(compressed_size, kBlockSize));
+ if (!TryCompressBlob(blob,
+ block_info,
+ zero_padding_enabled,
+ compression_algo,
+ [&output](const uint8_t* data, size_t size) {
+ output.insert(output.end(), data, data + size);
+ return size;
+ })) {
+ return {};
}
- // Any trailing data will be copied to the output buffer.
- output.insert(output.end(), blob.begin() + uncompressed_size, blob.end());
+
return output;
}
@@ -164,11 +197,6 @@
block.uncompressed_length,
block.uncompressed_length);
if (bytes_decompressed < 0) {
- Blob cluster_hash;
- HashCalculator::RawHashOfBytes(
- cluster.data(), cluster.size(), &cluster_hash);
- Blob blob_hash;
- HashCalculator::RawHashOfBytes(blob.data(), blob.size(), &blob_hash);
LOG(FATAL) << "Failed to decompress, " << bytes_decompressed
<< ", output_cursor = "
<< output.size() - block.uncompressed_length
@@ -177,7 +205,8 @@
<< ", cluster_size = " << block.compressed_length
<< ", dest capacity = " << block.uncompressed_length
<< ", input margin = " << inputmargin << " "
- << HexEncode(cluster_hash) << " " << HexEncode(blob_hash);
+ << HashCalculator::SHA256Digest(cluster) << " "
+ << HashCalculator::SHA256Digest(blob);
return {};
}
compressed_offset += block.compressed_length;
@@ -197,11 +226,6 @@
return output;
}
-[[nodiscard]] std::string_view ToStringView(const Blob& blob) noexcept {
- return std::string_view{reinterpret_cast<const char*>(blob.data()),
- blob.size()};
-}
-
Blob TryDecompressBlob(const Blob& blob,
const std::vector<CompressedBlock>& block_info,
const bool zero_padding_enabled) {
@@ -216,11 +240,6 @@
return out;
}
-[[nodiscard]] std::string_view ToStringView(const void* data,
- size_t size) noexcept {
- return std::string_view(reinterpret_cast<const char*>(data), size);
-}
-
std::ostream& operator<<(std::ostream& out, const CompressedBlockInfo& info) {
out << "BlockInfo { compressed_length: " << info.compressed_length()
<< ", uncompressed_length: " << info.uncompressed_length()