Adding compressor class
Adding in compressor class to clean up code for cow_compress.cpp.
Since we are making some api changes (to zstd) that are unique to each
compression methods, these should be implementation details should be
hidden to the parent class
Test: m libsnapshot
Change-Id: I9194e2c615aefed078458f525382253228bc1b69
diff --git a/fs_mgr/libsnapshot/include/libsnapshot/cow_compress.h b/fs_mgr/libsnapshot/include/libsnapshot/cow_compress.h
new file mode 100644
index 0000000..8add041
--- /dev/null
+++ b/fs_mgr/libsnapshot/include/libsnapshot/cow_compress.h
@@ -0,0 +1,48 @@
+//
+// Copyright (C) 2023 The Android Open Source Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+
+#pragma once
+
+#include <memory>
+#include <string_view>
+#include "libsnapshot/cow_format.h"
+
+namespace android {
+namespace snapshot {
+
+class ICompressor {
+ public:
+ explicit ICompressor(uint32_t compression_level) : compression_level_(compression_level) {}
+
+ virtual ~ICompressor() {}
+ // Factory methods for compression methods.
+ static std::unique_ptr<ICompressor> Gz(uint32_t compression_level);
+ static std::unique_ptr<ICompressor> Brotli(uint32_t compression_level);
+ static std::unique_ptr<ICompressor> Lz4(uint32_t compression_level);
+ static std::unique_ptr<ICompressor> Zstd(uint32_t compression_level);
+
+ static std::unique_ptr<ICompressor> Create(CowCompression compression);
+
+ uint32_t GetCompressionLevel() const { return compression_level_; }
+
+ [[nodiscard]] virtual std::basic_string<uint8_t> Compress(const void* data,
+ size_t length) const = 0;
+
+ private:
+ uint32_t compression_level_;
+};
+} // namespace snapshot
+} // namespace android
\ No newline at end of file
diff --git a/fs_mgr/libsnapshot/include/libsnapshot/cow_writer.h b/fs_mgr/libsnapshot/include/libsnapshot/cow_writer.h
index 74b8bb8..3016e93 100644
--- a/fs_mgr/libsnapshot/include/libsnapshot/cow_writer.h
+++ b/fs_mgr/libsnapshot/include/libsnapshot/cow_writer.h
@@ -14,6 +14,8 @@
#pragma once
+#include <libsnapshot/cow_compress.h>
+
#include <stdint.h>
#include <condition_variable>
@@ -107,16 +109,14 @@
class CompressWorker {
public:
- CompressWorker(CowCompression compression, uint32_t block_size);
+ CompressWorker(std::unique_ptr<ICompressor>&& compressor, uint32_t block_size);
bool RunThread();
void EnqueueCompressBlocks(const void* buffer, size_t num_blocks);
bool GetCompressedBuffers(std::vector<std::basic_string<uint8_t>>* compressed_buf);
void Finalize();
static uint32_t GetDefaultCompressionLevel(CowCompressionAlgorithm compression);
- static std::basic_string<uint8_t> Compress(CowCompression compression, const void* data,
- size_t length);
- static bool CompressBlocks(CowCompression compression, size_t block_size, const void* buffer,
+ static bool CompressBlocks(ICompressor* compressor, size_t block_size, const void* buffer,
size_t num_blocks,
std::vector<std::basic_string<uint8_t>>* compressed_data);
@@ -128,7 +128,7 @@
std::vector<std::basic_string<uint8_t>> compressed_data;
};
- CowCompression compression_;
+ std::unique_ptr<ICompressor> compressor_;
uint32_t block_size_;
std::queue<CompressWork> work_queue_;
diff --git a/fs_mgr/libsnapshot/libsnapshot_cow/cow_compress.cpp b/fs_mgr/libsnapshot/libsnapshot_cow/cow_compress.cpp
index 96d6016..d023f19 100644
--- a/fs_mgr/libsnapshot/libsnapshot_cow/cow_compress.cpp
+++ b/fs_mgr/libsnapshot/libsnapshot_cow/cow_compress.cpp
@@ -22,8 +22,11 @@
#include <android-base/file.h>
#include <android-base/logging.h>
+#include <android-base/parseint.h>
+#include <android-base/strings.h>
#include <android-base/unique_fd.h>
#include <brotli/encode.h>
+#include <libsnapshot/cow_compress.h>
#include <libsnapshot/cow_format.h>
#include <libsnapshot/cow_reader.h>
#include <libsnapshot/cow_writer.h>
@@ -51,6 +54,22 @@
}
}
+std::unique_ptr<ICompressor> ICompressor::Create(CowCompression compression) {
+ switch (compression.algorithm) {
+ case kCowCompressLz4:
+ return ICompressor::Lz4(compression.compression_level);
+ case kCowCompressBrotli:
+ return ICompressor::Brotli(compression.compression_level);
+ case kCowCompressGz:
+ return ICompressor::Gz(compression.compression_level);
+ case kCowCompressZstd:
+ return ICompressor::Zstd(compression.compression_level);
+ case kCowCompressNone:
+ return nullptr;
+ }
+ return nullptr;
+}
+
// 1. Default compression level is determined by compression algorithm
// 2. There might be compatibility issues if a value is changed here, as some older versions of
// Android will assume a different compression level, causing cow_size estimation differences that
@@ -77,101 +96,116 @@
return 0;
}
-std::basic_string<uint8_t> CompressWorker::Compress(CowCompression compression, const void* data,
- size_t length) {
- switch (compression.algorithm) {
- case kCowCompressGz: {
- const auto bound = compressBound(length);
- std::basic_string<uint8_t> buffer(bound, '\0');
+class GzCompressor final : public ICompressor {
+ public:
+ GzCompressor(uint32_t compression_level) : ICompressor(compression_level){};
- uLongf dest_len = bound;
- auto rv = compress2(buffer.data(), &dest_len, reinterpret_cast<const Bytef*>(data),
- length, compression.compression_level);
- if (rv != Z_OK) {
- LOG(ERROR) << "compress2 returned: " << rv;
- return {};
- }
- buffer.resize(dest_len);
- return buffer;
- }
- case kCowCompressBrotli: {
- const auto bound = BrotliEncoderMaxCompressedSize(length);
- if (!bound) {
- LOG(ERROR) << "BrotliEncoderMaxCompressedSize returned 0";
- return {};
- }
- std::basic_string<uint8_t> buffer(bound, '\0');
+ std::basic_string<uint8_t> Compress(const void* data, size_t length) const override {
+ const auto bound = compressBound(length);
+ std::basic_string<uint8_t> buffer(bound, '\0');
- size_t encoded_size = bound;
- auto rv = BrotliEncoderCompress(
- compression.compression_level, BROTLI_DEFAULT_WINDOW, BROTLI_DEFAULT_MODE,
- length, reinterpret_cast<const uint8_t*>(data), &encoded_size, buffer.data());
- if (!rv) {
- LOG(ERROR) << "BrotliEncoderCompress failed";
- return {};
- }
- buffer.resize(encoded_size);
- return buffer;
+ uLongf dest_len = bound;
+ auto rv = compress2(buffer.data(), &dest_len, reinterpret_cast<const Bytef*>(data), length,
+ GetCompressionLevel());
+ if (rv != Z_OK) {
+ LOG(ERROR) << "compress2 returned: " << rv;
+ return {};
}
- case kCowCompressLz4: {
- const auto bound = LZ4_compressBound(length);
- if (!bound) {
- LOG(ERROR) << "LZ4_compressBound returned 0";
- return {};
- }
- std::basic_string<uint8_t> buffer(bound, '\0');
+ buffer.resize(dest_len);
+ return buffer;
+ };
+};
- const auto compressed_size = LZ4_compress_default(
- static_cast<const char*>(data), reinterpret_cast<char*>(buffer.data()), length,
- buffer.size());
- if (compressed_size <= 0) {
- LOG(ERROR) << "LZ4_compress_default failed, input size: " << length
- << ", compression bound: " << bound << ", ret: " << compressed_size;
- return {};
- }
- // Don't run compression if the compressed output is larger
- if (compressed_size >= length) {
- buffer.resize(length);
- memcpy(buffer.data(), data, length);
- } else {
- buffer.resize(compressed_size);
- }
- return buffer;
+class Lz4Compressor final : public ICompressor {
+ public:
+ Lz4Compressor(uint32_t compression_level) : ICompressor(compression_level){};
+
+ std::basic_string<uint8_t> Compress(const void* data, size_t length) const override {
+ const auto bound = LZ4_compressBound(length);
+ if (!bound) {
+ LOG(ERROR) << "LZ4_compressBound returned 0";
+ return {};
}
- case kCowCompressZstd: {
- std::basic_string<uint8_t> buffer(ZSTD_compressBound(length), '\0');
- const auto compressed_size = ZSTD_compress(buffer.data(), buffer.size(), data, length,
- compression.compression_level);
- if (compressed_size <= 0) {
- LOG(ERROR) << "ZSTD compression failed " << compressed_size;
- return {};
- }
- // Don't run compression if the compressed output is larger
- if (compressed_size >= length) {
- buffer.resize(length);
- memcpy(buffer.data(), data, length);
- } else {
- buffer.resize(compressed_size);
- }
- return buffer;
+ std::basic_string<uint8_t> buffer(bound, '\0');
+
+ const auto compressed_size =
+ LZ4_compress_default(static_cast<const char*>(data),
+ reinterpret_cast<char*>(buffer.data()), length, buffer.size());
+ if (compressed_size <= 0) {
+ LOG(ERROR) << "LZ4_compress_default failed, input size: " << length
+ << ", compression bound: " << bound << ", ret: " << compressed_size;
+ return {};
}
- default:
- LOG(ERROR) << "unhandled compression type: " << compression.algorithm;
- break;
- }
- return {};
-}
+ // Don't run compression if the compressed output is larger
+ if (compressed_size >= length) {
+ buffer.resize(length);
+ memcpy(buffer.data(), data, length);
+ } else {
+ buffer.resize(compressed_size);
+ }
+ return buffer;
+ };
+};
+
+class BrotliCompressor final : public ICompressor {
+ public:
+ BrotliCompressor(uint32_t compression_level) : ICompressor(compression_level){};
+
+ std::basic_string<uint8_t> Compress(const void* data, size_t length) const override {
+ const auto bound = BrotliEncoderMaxCompressedSize(length);
+ if (!bound) {
+ LOG(ERROR) << "BrotliEncoderMaxCompressedSize returned 0";
+ return {};
+ }
+ std::basic_string<uint8_t> buffer(bound, '\0');
+
+ size_t encoded_size = bound;
+ auto rv = BrotliEncoderCompress(
+ GetCompressionLevel(), BROTLI_DEFAULT_WINDOW, BROTLI_DEFAULT_MODE, length,
+ reinterpret_cast<const uint8_t*>(data), &encoded_size, buffer.data());
+ if (!rv) {
+ LOG(ERROR) << "BrotliEncoderCompress failed";
+ return {};
+ }
+ buffer.resize(encoded_size);
+ return buffer;
+ };
+};
+
+class ZstdCompressor final : public ICompressor {
+ public:
+ ZstdCompressor(uint32_t compression_level) : ICompressor(compression_level){};
+
+ std::basic_string<uint8_t> Compress(const void* data, size_t length) const override {
+ std::basic_string<uint8_t> buffer(ZSTD_compressBound(length), '\0');
+ const auto compressed_size =
+ ZSTD_compress(buffer.data(), buffer.size(), data, length, GetCompressionLevel());
+ if (compressed_size <= 0) {
+ LOG(ERROR) << "ZSTD compression failed " << compressed_size;
+ return {};
+ }
+ // Don't run compression if the compressed output is larger
+ if (compressed_size >= length) {
+ buffer.resize(length);
+ memcpy(buffer.data(), data, length);
+ } else {
+ buffer.resize(compressed_size);
+ }
+ return buffer;
+ };
+};
+
bool CompressWorker::CompressBlocks(const void* buffer, size_t num_blocks,
std::vector<std::basic_string<uint8_t>>* compressed_data) {
- return CompressBlocks(compression_, block_size_, buffer, num_blocks, compressed_data);
+ return CompressBlocks(compressor_.get(), block_size_, buffer, num_blocks, compressed_data);
}
-bool CompressWorker::CompressBlocks(CowCompression compression, size_t block_size,
- const void* buffer, size_t num_blocks,
+bool CompressWorker::CompressBlocks(ICompressor* compressor, size_t block_size, const void* buffer,
+ size_t num_blocks,
std::vector<std::basic_string<uint8_t>>* compressed_data) {
const uint8_t* iter = reinterpret_cast<const uint8_t*>(buffer);
while (num_blocks) {
- auto data = Compress(compression, iter, block_size);
+ auto data = compressor->Compress(iter, block_size);
if (data.empty()) {
PLOG(ERROR) << "CompressBlocks: Compression failed";
return false;
@@ -270,6 +304,22 @@
return true;
}
+std::unique_ptr<ICompressor> ICompressor::Brotli(uint32_t compression_level) {
+ return std::make_unique<BrotliCompressor>(compression_level);
+}
+
+std::unique_ptr<ICompressor> ICompressor::Gz(uint32_t compression_level) {
+ return std::make_unique<GzCompressor>(compression_level);
+}
+
+std::unique_ptr<ICompressor> ICompressor::Lz4(uint32_t compression_level) {
+ return std::make_unique<Lz4Compressor>(compression_level);
+}
+
+std::unique_ptr<ICompressor> ICompressor::Zstd(uint32_t compression_level) {
+ return std::make_unique<ZstdCompressor>(compression_level);
+}
+
void CompressWorker::Finalize() {
{
std::unique_lock<std::mutex> lock(lock_);
@@ -278,8 +328,8 @@
cv_.notify_all();
}
-CompressWorker::CompressWorker(CowCompression compression, uint32_t block_size)
- : compression_(compression), block_size_(block_size) {}
+CompressWorker::CompressWorker(std::unique_ptr<ICompressor>&& compressor, uint32_t block_size)
+ : compressor_(std::move(compressor)), block_size_(block_size) {}
} // namespace snapshot
} // namespace android
diff --git a/fs_mgr/libsnapshot/libsnapshot_cow/test_v2.cpp b/fs_mgr/libsnapshot/libsnapshot_cow/test_v2.cpp
index 2258d9f..0ecad9d 100644
--- a/fs_mgr/libsnapshot/libsnapshot_cow/test_v2.cpp
+++ b/fs_mgr/libsnapshot/libsnapshot_cow/test_v2.cpp
@@ -480,7 +480,8 @@
std::string expected = "The quick brown fox jumps over the lazy dog.";
expected.resize(4096, '\0');
- auto result = CompressWorker::Compress(compression, expected.data(), expected.size());
+ std::unique_ptr<ICompressor> compressor = ICompressor::Create(compression);
+ auto result = compressor->Compress(expected.data(), expected.size());
ASSERT_FALSE(result.empty());
HorribleStream<uint8_t> stream(result);
diff --git a/fs_mgr/libsnapshot/libsnapshot_cow/writer_v2.cpp b/fs_mgr/libsnapshot/libsnapshot_cow/writer_v2.cpp
index 6d04c6a..7115821 100644
--- a/fs_mgr/libsnapshot/libsnapshot_cow/writer_v2.cpp
+++ b/fs_mgr/libsnapshot/libsnapshot_cow/writer_v2.cpp
@@ -184,7 +184,8 @@
return;
}
for (int i = 0; i < num_compress_threads_; i++) {
- auto wt = std::make_unique<CompressWorker>(compression_, header_.block_size);
+ std::unique_ptr<ICompressor> compressor = ICompressor::Create(compression_);
+ auto wt = std::make_unique<CompressWorker>(std::move(compressor), header_.block_size);
threads_.emplace_back(std::async(std::launch::async, &CompressWorker::RunThread, wt.get()));
compress_threads_.push_back(std::move(wt));
}
@@ -339,10 +340,12 @@
const uint8_t* iter = reinterpret_cast<const uint8_t*>(data);
compressed_buf_.clear();
if (num_threads <= 1) {
- return CompressWorker::CompressBlocks(compression_, options_.block_size, data, num_blocks,
- &compressed_buf_);
+ if (!compressor_) {
+ compressor_ = ICompressor::Create(compression_);
+ }
+ return CompressWorker::CompressBlocks(compressor_.get(), options_.block_size, data,
+ num_blocks, &compressed_buf_);
}
-
// Submit the blocks per thread. The retrieval of
// compressed buffers has to be done in the same order.
// We should not poll for completed buffers in a different order as the
@@ -412,8 +415,11 @@
buf_iter_++;
return data;
} else {
- auto data =
- CompressWorker::Compress(compression_, iter, header_.block_size);
+ if (!compressor_) {
+ compressor_ = ICompressor::Create(compression_);
+ }
+
+ auto data = compressor_->Compress(iter, header_.block_size);
return data;
}
}();
diff --git a/fs_mgr/libsnapshot/libsnapshot_cow/writer_v2.h b/fs_mgr/libsnapshot/libsnapshot_cow/writer_v2.h
index 3f357e0..131a068 100644
--- a/fs_mgr/libsnapshot/libsnapshot_cow/writer_v2.h
+++ b/fs_mgr/libsnapshot/libsnapshot_cow/writer_v2.h
@@ -65,6 +65,9 @@
private:
CowFooter footer_{};
CowCompression compression_;
+ // in the case that we are using one thread for compression, we can store and re-use the same
+ // compressor
+ std::unique_ptr<ICompressor> compressor_;
uint64_t current_op_pos_ = 0;
uint64_t next_op_pos_ = 0;
uint64_t next_data_pos_ = 0;