libsnapshot: Refactor cow_reader decompression.
Bug: 162274240
Test: cow_api_test
Change-Id: I12c177f3ebb7bb0550669bd5edbdbbde6f572cfd
diff --git a/fs_mgr/libsnapshot/Android.bp b/fs_mgr/libsnapshot/Android.bp
index d11d3e4..bdf1da6 100644
--- a/fs_mgr/libsnapshot/Android.bp
+++ b/fs_mgr/libsnapshot/Android.bp
@@ -134,6 +134,7 @@
],
export_include_dirs: ["include"],
srcs: [
+ "cow_decompress.cpp",
"cow_reader.cpp",
"cow_writer.cpp",
],
diff --git a/fs_mgr/libsnapshot/cow_decompress.cpp b/fs_mgr/libsnapshot/cow_decompress.cpp
new file mode 100644
index 0000000..f480b85
--- /dev/null
+++ b/fs_mgr/libsnapshot/cow_decompress.cpp
@@ -0,0 +1,211 @@
+//
+// Copyright (C) 2020 The Android Open Source Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+
+#include "cow_decompress.h"
+
+#include <utility>
+
+#include <android-base/logging.h>
+#include <zlib.h>
+
+namespace android {
+namespace snapshot {
+
+class NoDecompressor final : public IDecompressor {
+ public:
+ bool Decompress(size_t) override;
+};
+
+bool NoDecompressor::Decompress(size_t) {
+ size_t stream_remaining = stream_->Size();
+ while (stream_remaining) {
+ size_t buffer_size = stream_remaining;
+ uint8_t* buffer = reinterpret_cast<uint8_t*>(sink_->GetBuffer(buffer_size, &buffer_size));
+ if (!buffer) {
+ LOG(ERROR) << "Could not acquire buffer from sink";
+ return false;
+ }
+
+ // Read until we can fill the buffer.
+ uint8_t* buffer_pos = buffer;
+ size_t bytes_to_read = std::min(buffer_size, stream_remaining);
+ while (bytes_to_read) {
+ size_t read;
+ if (!stream_->Read(buffer_pos, bytes_to_read, &read)) {
+ return false;
+ }
+ if (!read) {
+ LOG(ERROR) << "Stream ended prematurely";
+ return false;
+ }
+ if (!sink_->ReturnData(buffer_pos, read)) {
+ LOG(ERROR) << "Could not return buffer to sink";
+ return false;
+ }
+ buffer_pos += read;
+ bytes_to_read -= read;
+ stream_remaining -= read;
+ }
+ }
+ return true;
+}
+
+std::unique_ptr<IDecompressor> IDecompressor::Uncompressed() {
+ return std::unique_ptr<IDecompressor>(new NoDecompressor());
+}
+
+// Read chunks of the COW and incrementally stream them to the decoder.
+class StreamDecompressor : public IDecompressor {
+ public:
+ bool Decompress(size_t output_bytes) override;
+
+ virtual bool Init() = 0;
+ virtual bool DecompressInput(const uint8_t* data, size_t length) = 0;
+ virtual bool Done() = 0;
+
+ protected:
+ bool GetFreshBuffer();
+
+ size_t output_bytes_;
+ size_t stream_remaining_;
+ uint8_t* output_buffer_ = nullptr;
+ size_t output_buffer_remaining_ = 0;
+};
+
+static constexpr size_t kChunkSize = 4096;
+
+bool StreamDecompressor::Decompress(size_t output_bytes) {
+ if (!Init()) {
+ return false;
+ }
+
+ stream_remaining_ = stream_->Size();
+ output_bytes_ = output_bytes;
+
+ uint8_t chunk[kChunkSize];
+ while (stream_remaining_) {
+ size_t read = std::min(stream_remaining_, sizeof(chunk));
+ if (!stream_->Read(chunk, read, &read)) {
+ return false;
+ }
+ if (!read) {
+ LOG(ERROR) << "Stream ended prematurely";
+ return false;
+ }
+ if (!DecompressInput(chunk, read)) {
+ return false;
+ }
+
+ stream_remaining_ -= read;
+
+ if (stream_remaining_ && Done()) {
+ LOG(ERROR) << "Decompressor terminated early";
+ return false;
+ }
+ }
+ if (!Done()) {
+ LOG(ERROR) << "Decompressor expected more bytes";
+ return false;
+ }
+ return true;
+}
+
+bool StreamDecompressor::GetFreshBuffer() {
+ size_t request_size = std::min(output_bytes_, kChunkSize);
+ output_buffer_ =
+ reinterpret_cast<uint8_t*>(sink_->GetBuffer(request_size, &output_buffer_remaining_));
+ if (!output_buffer_) {
+ LOG(ERROR) << "Could not acquire buffer from sink";
+ return false;
+ }
+ return true;
+}
+
+class GzDecompressor final : public StreamDecompressor {
+ public:
+ ~GzDecompressor();
+
+ bool Init() override;
+ bool DecompressInput(const uint8_t* data, size_t length) override;
+ bool Done() override { return ended_; }
+
+ private:
+ z_stream z_ = {};
+ bool ended_ = false;
+};
+
+bool GzDecompressor::Init() {
+ if (int rv = inflateInit(&z_); rv != Z_OK) {
+ LOG(ERROR) << "inflateInit returned error code " << rv;
+ return false;
+ }
+ return true;
+}
+
+GzDecompressor::~GzDecompressor() {
+ inflateEnd(&z_);
+}
+
+bool GzDecompressor::DecompressInput(const uint8_t* data, size_t length) {
+ z_.next_in = reinterpret_cast<Bytef*>(const_cast<uint8_t*>(data));
+ z_.avail_in = length;
+
+ while (z_.avail_in) {
+ // If no more output buffer, grab a new buffer.
+ if (z_.avail_out == 0) {
+ if (!GetFreshBuffer()) {
+ return false;
+ }
+ z_.next_out = reinterpret_cast<Bytef*>(output_buffer_);
+ z_.avail_out = output_buffer_remaining_;
+ }
+
+ // Remember the position of the output buffer so we can call ReturnData.
+ auto avail_out = z_.avail_out;
+
+ // Decompress.
+ int rv = inflate(&z_, Z_NO_FLUSH);
+ if (rv != Z_OK && rv != Z_STREAM_END) {
+ LOG(ERROR) << "inflate returned error code " << rv;
+ return false;
+ }
+
+ size_t returned = avail_out - z_.avail_out;
+ if (!sink_->ReturnData(output_buffer_, returned)) {
+ LOG(ERROR) << "Could not return buffer to sink";
+ return false;
+ }
+ output_buffer_ += returned;
+ output_buffer_remaining_ -= returned;
+
+ if (rv == Z_STREAM_END) {
+ if (z_.avail_in) {
+ LOG(ERROR) << "Gz stream ended prematurely";
+ return false;
+ }
+ ended_ = true;
+ return true;
+ }
+ }
+ return true;
+}
+
+std::unique_ptr<IDecompressor> IDecompressor::Gz() {
+ return std::unique_ptr<IDecompressor>(new GzDecompressor());
+}
+
+} // namespace snapshot
+} // namespace android
diff --git a/fs_mgr/libsnapshot/cow_decompress.h b/fs_mgr/libsnapshot/cow_decompress.h
new file mode 100644
index 0000000..1c8c40d
--- /dev/null
+++ b/fs_mgr/libsnapshot/cow_decompress.h
@@ -0,0 +1,56 @@
+//
+// Copyright (C) 2020 The Android Open Source Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+
+#pragma once
+
+#include <libsnapshot/cow_reader.h>
+
+namespace android {
+namespace snapshot {
+
+class IByteStream {
+ public:
+ virtual ~IByteStream() {}
+
+ // Read up to |length| bytes, storing the number of bytes read in the out-
+ // parameter. If the end of the stream is reached, 0 is returned.
+ virtual bool Read(void* buffer, size_t length, size_t* read) = 0;
+
+ // Size of the stream.
+ virtual size_t Size() const = 0;
+};
+
+class IDecompressor {
+ public:
+ virtual ~IDecompressor() {}
+
+ // Factory methods for decompression methods.
+ static std::unique_ptr<IDecompressor> Uncompressed();
+ static std::unique_ptr<IDecompressor> Gz();
+
+ // |output_bytes| is the expected total number of bytes to sink.
+ virtual bool Decompress(size_t output_bytes) = 0;
+
+ void set_stream(IByteStream* stream) { stream_ = stream; }
+ void set_sink(IByteSink* sink) { sink_ = sink; }
+
+ protected:
+ IByteStream* stream_ = nullptr;
+ IByteSink* sink_ = nullptr;
+};
+
+} // namespace snapshot
+} // namespace android
diff --git a/fs_mgr/libsnapshot/cow_reader.cpp b/fs_mgr/libsnapshot/cow_reader.cpp
index 7f77aec..1aea3a9 100644
--- a/fs_mgr/libsnapshot/cow_reader.cpp
+++ b/fs_mgr/libsnapshot/cow_reader.cpp
@@ -17,10 +17,13 @@
#include <sys/types.h>
#include <unistd.h>
+#include <limits>
+
#include <android-base/file.h>
#include <android-base/logging.h>
#include <libsnapshot/cow_reader.h>
#include <zlib.h>
+#include "cow_decompress.h"
namespace android {
namespace snapshot {
@@ -171,7 +174,7 @@
return std::make_unique<CowOpIter>(std::move(ops_buffer), header_.ops_size);
}
-bool CowReader::GetRawBytes(uint64_t offset, void* buffer, size_t len) {
+bool CowReader::GetRawBytes(uint64_t offset, void* buffer, size_t len, size_t* read) {
// Validate the offset, taking care to acknowledge possible overflow of offset+len.
if (offset < sizeof(header_) || offset >= header_.ops_offset || len >= fd_size_ ||
offset + len > header_.ops_offset) {
@@ -182,104 +185,63 @@
PLOG(ERROR) << "lseek to read raw bytes failed";
return false;
}
- if (!android::base::ReadFully(fd_, buffer, len)) {
- PLOG(ERROR) << "read raw bytes failed";
+ ssize_t rv = TEMP_FAILURE_RETRY(::read(fd_.get(), buffer, len));
+ if (rv < 0) {
+ PLOG(ERROR) << "read failed";
return false;
}
+ *read = rv;
return true;
}
+class CowDataStream final : public IByteStream {
+ public:
+ CowDataStream(CowReader* reader, uint64_t offset, size_t data_length)
+ : reader_(reader), offset_(offset), data_length_(data_length) {
+ remaining_ = data_length_;
+ }
+
+ bool Read(void* buffer, size_t length, size_t* read) override {
+ size_t to_read = std::min(length, remaining_);
+ if (!to_read) {
+ *read = 0;
+ return true;
+ }
+ if (!reader_->GetRawBytes(offset_, buffer, to_read, read)) {
+ return false;
+ }
+ offset_ += *read;
+ remaining_ -= *read;
+ return true;
+ }
+
+ size_t Size() const override { return data_length_; }
+
+ private:
+ CowReader* reader_;
+ uint64_t offset_;
+ size_t data_length_;
+ size_t remaining_;
+};
+
bool CowReader::ReadData(const CowOperation& op, IByteSink* sink) {
- uint64_t offset = op.source;
-
+ std::unique_ptr<IDecompressor> decompressor;
switch (op.compression) {
- case kCowCompressNone: {
- size_t remaining = op.data_length;
- while (remaining) {
- size_t amount = remaining;
- void* buffer = sink->GetBuffer(amount, &amount);
- if (!buffer) {
- LOG(ERROR) << "Could not acquire buffer from sink";
- return false;
- }
- if (!GetRawBytes(offset, buffer, amount)) {
- return false;
- }
- if (!sink->ReturnData(buffer, amount)) {
- LOG(ERROR) << "Could not return buffer to sink";
- return false;
- }
- remaining -= amount;
- offset += amount;
- }
- return true;
- }
- case kCowCompressGz: {
- auto input = std::make_unique<Bytef[]>(op.data_length);
- if (!GetRawBytes(offset, input.get(), op.data_length)) {
- return false;
- }
-
- z_stream z = {};
- z.next_in = input.get();
- z.avail_in = op.data_length;
- if (int rv = inflateInit(&z); rv != Z_OK) {
- LOG(ERROR) << "inflateInit returned error code " << rv;
- return false;
- }
-
- while (z.total_out < header_.block_size) {
- // If no more output buffer, grab a new buffer.
- if (z.avail_out == 0) {
- size_t amount = header_.block_size - z.total_out;
- z.next_out = reinterpret_cast<Bytef*>(sink->GetBuffer(amount, &amount));
- if (!z.next_out) {
- LOG(ERROR) << "Could not acquire buffer from sink";
- return false;
- }
- z.avail_out = amount;
- }
-
- // Remember the position of the output buffer so we can call ReturnData.
- auto buffer = z.next_out;
- auto avail_out = z.avail_out;
-
- // Decompress.
- int rv = inflate(&z, Z_NO_FLUSH);
- if (rv != Z_OK && rv != Z_STREAM_END) {
- LOG(ERROR) << "inflate returned error code " << rv;
- return false;
- }
-
- // Return the section of the buffer that was updated.
- if (z.avail_out < avail_out && !sink->ReturnData(buffer, avail_out - z.avail_out)) {
- LOG(ERROR) << "Could not return buffer to sink";
- return false;
- }
-
- if (rv == Z_STREAM_END) {
- // Error if the stream has ended, but we didn't fill the entire block.
- if (z.total_out != header_.block_size) {
- LOG(ERROR) << "Reached gz stream end but did not read a full block of data";
- return false;
- }
- break;
- }
-
- CHECK(rv == Z_OK);
-
- // Error if the stream is expecting more data, but we don't have any to read.
- if (z.avail_in == 0) {
- LOG(ERROR) << "Gz stream ended prematurely";
- return false;
- }
- }
- return true;
- }
+ case kCowCompressNone:
+ decompressor = IDecompressor::Uncompressed();
+ break;
+ case kCowCompressGz:
+ decompressor = IDecompressor::Gz();
+ break;
default:
LOG(ERROR) << "Unknown compression type: " << op.compression;
return false;
}
+
+ CowDataStream stream(this, op.source, op.data_length);
+ decompressor->set_stream(&stream);
+ decompressor->set_sink(sink);
+ return decompressor->Decompress(header_.block_size);
}
} // namespace snapshot
diff --git a/fs_mgr/libsnapshot/include/libsnapshot/cow_reader.h b/fs_mgr/libsnapshot/include/libsnapshot/cow_reader.h
index 9e9f9b8..3998776 100644
--- a/fs_mgr/libsnapshot/include/libsnapshot/cow_reader.h
+++ b/fs_mgr/libsnapshot/include/libsnapshot/cow_reader.h
@@ -61,9 +61,6 @@
// Return an iterator for retrieving CowOperation entries.
virtual std::unique_ptr<ICowOpIter> GetOpIter() = 0;
- // Get raw bytes from the data section.
- virtual bool GetRawBytes(uint64_t offset, void* buffer, size_t len) = 0;
-
// Get decoded bytes from the data section, handling any decompression.
// All retrieved data is passed to the sink.
virtual bool ReadData(const CowOperation& op, IByteSink* sink) = 0;
@@ -97,9 +94,10 @@
// CowOperation objects. Get() returns a unique CowOperation object
// whose lifeteime depends on the CowOpIter object
std::unique_ptr<ICowOpIter> GetOpIter() override;
- bool GetRawBytes(uint64_t offset, void* buffer, size_t len) override;
bool ReadData(const CowOperation& op, IByteSink* sink) override;
+ bool GetRawBytes(uint64_t offset, void* buffer, size_t len, size_t* read);
+
private:
android::base::unique_fd owned_fd_;
android::base::borrowed_fd fd_;