Merge "libsnapshot: Refactor cow_reader decompression."
diff --git a/fs_mgr/libsnapshot/Android.bp b/fs_mgr/libsnapshot/Android.bp
index d11d3e4..bdf1da6 100644
--- a/fs_mgr/libsnapshot/Android.bp
+++ b/fs_mgr/libsnapshot/Android.bp
@@ -134,6 +134,7 @@
     ],
     export_include_dirs: ["include"],
     srcs: [
+        "cow_decompress.cpp",
         "cow_reader.cpp",
         "cow_writer.cpp",
     ],
diff --git a/fs_mgr/libsnapshot/cow_decompress.cpp b/fs_mgr/libsnapshot/cow_decompress.cpp
new file mode 100644
index 0000000..f480b85
--- /dev/null
+++ b/fs_mgr/libsnapshot/cow_decompress.cpp
@@ -0,0 +1,264 @@
+//
+// Copyright (C) 2020 The Android Open Source Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+
+#include "cow_decompress.h"
+
+#include <algorithm>
+#include <memory>
+#include <utility>
+
+#include <android-base/logging.h>
+#include <zlib.h>
+
+namespace android {
+namespace snapshot {
+
+// Forwards the stream to the sink unchanged. |output_bytes| is not consulted:
+// exactly stream_->Size() bytes are copied.
+class NoDecompressor final : public IDecompressor {
+  public:
+    bool Decompress(size_t) override;
+};
+
+bool NoDecompressor::Decompress(size_t) {
+    size_t stream_remaining = stream_->Size();
+    while (stream_remaining) {
+        size_t buffer_size = stream_remaining;
+        uint8_t* buffer = reinterpret_cast<uint8_t*>(sink_->GetBuffer(buffer_size, &buffer_size));
+        // A zero-sized buffer would make no progress and loop forever.
+        if (!buffer || !buffer_size) {
+            LOG(ERROR) << "Could not acquire buffer from sink";
+            return false;
+        }
+
+        // Read until we can fill the buffer, handing each (possibly short)
+        // read to the sink as soon as it arrives.
+        uint8_t* buffer_pos = buffer;
+        size_t bytes_to_read = std::min(buffer_size, stream_remaining);
+        while (bytes_to_read) {
+            size_t read = 0;
+            if (!stream_->Read(buffer_pos, bytes_to_read, &read)) {
+                return false;
+            }
+            if (!read) {
+                LOG(ERROR) << "Stream ended prematurely";
+                return false;
+            }
+            if (!sink_->ReturnData(buffer_pos, read)) {
+                LOG(ERROR) << "Could not return buffer to sink";
+                return false;
+            }
+            buffer_pos += read;
+            bytes_to_read -= read;
+            stream_remaining -= read;
+        }
+    }
+    return true;
+}
+
+std::unique_ptr<IDecompressor> IDecompressor::Uncompressed() {
+    return std::make_unique<NoDecompressor>();
+}
+
+// Reads the COW data in kChunkSize pieces and incrementally streams them to a
+// decoder. Subclasses write decoded data into buffers obtained with
+// GetFreshBuffer() and hand it to the sink with ReturnOutput(), which counts
+// it against the |output_bytes| budget passed to Decompress().
+class StreamDecompressor : public IDecompressor {
+  public:
+    bool Decompress(size_t output_bytes) override;
+
+    // Prepares the decoder; called at the start of Decompress().
+    virtual bool Init() = 0;
+    // Feeds |length| bytes of compressed input to the decoder.
+    virtual bool DecompressInput(const uint8_t* data, size_t length) = 0;
+    // Returns true once the decoder has reached the end of its stream.
+    virtual bool Done() = 0;
+
+  protected:
+    // Requests a buffer of min(output_bytes_, kChunkSize) bytes from the sink.
+    // Only called while output_bytes_ is non-zero.
+    bool GetFreshBuffer();
+    // Passes |bytes| freshly decoded bytes at |output_buffer_| to the sink and
+    // advances the buffer cursor and the remaining output budget.
+    bool ReturnOutput(size_t bytes);
+
+    // Decoded bytes still expected before the output is complete.
+    size_t output_bytes_ = 0;
+    // Compressed bytes not yet read from |stream_|.
+    size_t stream_remaining_ = 0;
+    uint8_t* output_buffer_ = nullptr;
+    size_t output_buffer_remaining_ = 0;
+};
+
+static constexpr size_t kChunkSize = 4096;
+
+bool StreamDecompressor::Decompress(size_t output_bytes) {
+    if (!Init()) {
+        return false;
+    }
+
+    stream_remaining_ = stream_->Size();
+    output_bytes_ = output_bytes;
+
+    uint8_t chunk[kChunkSize];
+    while (stream_remaining_) {
+        size_t read = std::min(stream_remaining_, sizeof(chunk));
+        if (!stream_->Read(chunk, read, &read)) {
+            return false;
+        }
+        if (!read) {
+            LOG(ERROR) << "Stream ended prematurely";
+            return false;
+        }
+        if (!DecompressInput(chunk, read)) {
+            return false;
+        }
+
+        stream_remaining_ -= read;
+
+        if (stream_remaining_ && Done()) {
+            LOG(ERROR) << "Decompressor terminated early";
+            return false;
+        }
+    }
+    if (!Done()) {
+        LOG(ERROR) << "Decompressor expected more bytes";
+        return false;
+    }
+    // All decoded output is counted against the budget, so a remainder means
+    // the stream decoded to less data than the caller expected.
+    if (output_bytes_) {
+        LOG(ERROR) << "Decompressed data is " << output_bytes_
+                   << " bytes shorter than expected";
+        return false;
+    }
+    return true;
+}
+
+bool StreamDecompressor::GetFreshBuffer() {
+    size_t request_size = std::min(output_bytes_, kChunkSize);
+    output_buffer_ =
+            reinterpret_cast<uint8_t*>(sink_->GetBuffer(request_size, &output_buffer_remaining_));
+    // A zero-sized buffer would leave the decoder unable to make progress.
+    if (!output_buffer_ || !output_buffer_remaining_) {
+        LOG(ERROR) << "Could not acquire buffer from sink";
+        return false;
+    }
+    return true;
+}
+
+bool StreamDecompressor::ReturnOutput(size_t bytes) {
+    // Nothing was written; don't bother the sink with an empty return.
+    if (!bytes) {
+        return true;
+    }
+    if (bytes > output_bytes_) {
+        LOG(ERROR) << "Decompressed data exceeds expected size";
+        return false;
+    }
+    if (!sink_->ReturnData(output_buffer_, bytes)) {
+        LOG(ERROR) << "Could not return buffer to sink";
+        return false;
+    }
+    output_buffer_ += bytes;
+    output_buffer_remaining_ -= bytes;
+    output_bytes_ -= bytes;
+    return true;
+}
+
+// Inflates zlib-wrapped deflate data (as set up by inflateInit()).
+class GzDecompressor final : public StreamDecompressor {
+  public:
+    ~GzDecompressor();
+
+    bool Init() override;
+    bool DecompressInput(const uint8_t* data, size_t length) override;
+    bool Done() override { return ended_; }
+
+  private:
+    z_stream z_ = {};
+    bool ended_ = false;
+};
+
+bool GzDecompressor::Init() {
+    if (int rv = inflateInit(&z_); rv != Z_OK) {
+        LOG(ERROR) << "inflateInit returned error code " << rv;
+        return false;
+    }
+    return true;
+}
+
+GzDecompressor::~GzDecompressor() {
+    inflateEnd(&z_);
+}
+
+bool GzDecompressor::DecompressInput(const uint8_t* data, size_t length) {
+    z_.next_in = reinterpret_cast<Bytef*>(const_cast<uint8_t*>(data));
+    z_.avail_in = length;
+
+    while (z_.avail_in) {
+        // Grab a new output buffer once the current one is full, but only while
+        // output is still expected. Beyond that, inflate() runs with no output
+        // space: zlib can still consume the end-of-block code and stream
+        // trailer, while any further output surfaces as Z_BUF_ERROR below.
+        if (z_.avail_out == 0 && output_bytes_) {
+            if (!GetFreshBuffer()) {
+                return false;
+            }
+            z_.next_out = reinterpret_cast<Bytef*>(output_buffer_);
+            // Never let inflate() write past the expected output size, even if
+            // the sink handed back more space than was requested.
+            z_.avail_out = std::min(output_buffer_remaining_, output_bytes_);
+        }
+
+        // Remember the free space so we can tell how much inflate() wrote.
+        auto avail_out = z_.avail_out;
+
+        // Decompress.
+        int rv = inflate(&z_, Z_NO_FLUSH);
+        if (rv == Z_BUF_ERROR && z_.avail_out == 0) {
+            LOG(ERROR) << "Decompressed data exceeds expected size";
+            return false;
+        }
+        if (rv != Z_OK && rv != Z_STREAM_END) {
+            LOG(ERROR) << "inflate returned error code " << rv;
+            return false;
+        }
+
+        if (!ReturnOutput(avail_out - z_.avail_out)) {
+            return false;
+        }
+
+        if (rv == Z_STREAM_END) {
+            // Input left over after the end of the zlib stream is an error.
+            if (z_.avail_in) {
+                LOG(ERROR) << "Gz stream ended prematurely";
+                return false;
+            }
+            ended_ = true;
+            return true;
+        }
+    }
+    return true;
+}
+
+std::unique_ptr<IDecompressor> IDecompressor::Gz() {
+    return std::make_unique<GzDecompressor>();
+}
+
+}  // namespace snapshot
+}  // namespace android
diff --git a/fs_mgr/libsnapshot/cow_decompress.h b/fs_mgr/libsnapshot/cow_decompress.h
new file mode 100644
index 0000000..1c8c40d
--- /dev/null
+++ b/fs_mgr/libsnapshot/cow_decompress.h
@@ -0,0 +1,66 @@
+//
+// Copyright (C) 2020 The Android Open Source Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+
+#pragma once
+
+#include <stddef.h>
+
+#include <memory>
+
+#include <libsnapshot/cow_reader.h>
+
+namespace android {
+namespace snapshot {
+
+// A source of bytes with a known total size, consumed by sequential Read() calls.
+class IByteStream {
+  public:
+    virtual ~IByteStream() {}
+
+    // Read up to |length| bytes into |buffer|, storing the number of bytes
+    // actually read in |*read|. Returns false on error. Once the end of the
+    // stream has been reached, the call succeeds with |*read| set to 0.
+    virtual bool Read(void* buffer, size_t length, size_t* read) = 0;
+
+    // Total size of the stream, in bytes.
+    virtual size_t Size() const = 0;
+};
+
+// Decodes the bytes of an IByteStream and writes the result to an IByteSink.
+class IDecompressor {
+  public:
+    virtual ~IDecompressor() {}
+
+    // Factory methods for decompression methods.
+    static std::unique_ptr<IDecompressor> Uncompressed();
+    static std::unique_ptr<IDecompressor> Gz();
+
+    // |output_bytes| is the expected total number of bytes to sink.
+    // Returns false on failure.
+    virtual bool Decompress(size_t output_bytes) = 0;
+
+    // The stream and sink are borrowed, not owned; both must be set and stay
+    // alive for the duration of Decompress().
+    void set_stream(IByteStream* stream) { stream_ = stream; }
+    void set_sink(IByteSink* sink) { sink_ = sink; }
+
+  protected:
+    IByteStream* stream_ = nullptr;
+    IByteSink* sink_ = nullptr;
+};
+
+}  // namespace snapshot
+}  // namespace android
diff --git a/fs_mgr/libsnapshot/cow_reader.cpp b/fs_mgr/libsnapshot/cow_reader.cpp
index 7f77aec..1aea3a9 100644
--- a/fs_mgr/libsnapshot/cow_reader.cpp
+++ b/fs_mgr/libsnapshot/cow_reader.cpp
@@ -17,10 +17,13 @@
 #include <sys/types.h>
 #include <unistd.h>
 
+#include <limits>
+
 #include <android-base/file.h>
 #include <android-base/logging.h>
 #include <libsnapshot/cow_reader.h>
 #include <zlib.h>
+#include "cow_decompress.h"
 
 namespace android {
 namespace snapshot {
@@ -171,7 +174,7 @@
     return std::make_unique<CowOpIter>(std::move(ops_buffer), header_.ops_size);
 }
 
-bool CowReader::GetRawBytes(uint64_t offset, void* buffer, size_t len) {
+bool CowReader::GetRawBytes(uint64_t offset, void* buffer, size_t len, size_t* read) {
     // Validate the offset, taking care to acknowledge possible overflow of offset+len.
     if (offset < sizeof(header_) || offset >= header_.ops_offset || len >= fd_size_ ||
         offset + len > header_.ops_offset) {
@@ -182,104 +185,63 @@
         PLOG(ERROR) << "lseek to read raw bytes failed";
         return false;
     }
-    if (!android::base::ReadFully(fd_, buffer, len)) {
-        PLOG(ERROR) << "read raw bytes failed";
+    ssize_t rv = TEMP_FAILURE_RETRY(::read(fd_.get(), buffer, len));
+    if (rv < 0) {
+        PLOG(ERROR) << "read failed";
         return false;
     }
+    *read = rv;
     return true;
 }
 
+// Exposes |data_length| bytes of a COW file's data section, starting at
+// |offset|, as an IByteStream backed by CowReader::GetRawBytes().
+class CowDataStream final : public IByteStream {
+  public:
+    CowDataStream(CowReader* reader, uint64_t offset, size_t data_length)
+        : reader_(reader), offset_(offset), data_length_(data_length), unread_(data_length) {}
+
+    bool Read(void* buffer, size_t length, size_t* read) override {
+        const size_t want = std::min(length, unread_);
+        if (want == 0) {
+            *read = 0;  // Exhausted (or empty request): report a zero-byte read.
+            return true;
+        }
+        if (!reader_->GetRawBytes(offset_, buffer, want, read)) {
+            return false;
+        }
+        offset_ += *read;
+        unread_ -= *read;
+        return true;
+    }
+
+    size_t Size() const override { return data_length_; }
+
+  private:
+    CowReader* reader_;
+    uint64_t offset_;      // Offset handed to the next GetRawBytes() call.
+    size_t data_length_;   // Total stream length, as reported by Size().
+    size_t unread_;        // Bytes not yet returned by Read().
+};
+
 bool CowReader::ReadData(const CowOperation& op, IByteSink* sink) {
-    uint64_t offset = op.source;
-
+    std::unique_ptr<IDecompressor> decompressor;  // Selected by op.compression below.
     switch (op.compression) {
-        case kCowCompressNone: {
-            size_t remaining = op.data_length;
-            while (remaining) {
-                size_t amount = remaining;
-                void* buffer = sink->GetBuffer(amount, &amount);
-                if (!buffer) {
-                    LOG(ERROR) << "Could not acquire buffer from sink";
-                    return false;
-                }
-                if (!GetRawBytes(offset, buffer, amount)) {
-                    return false;
-                }
-                if (!sink->ReturnData(buffer, amount)) {
-                    LOG(ERROR) << "Could not return buffer to sink";
-                    return false;
-                }
-                remaining -= amount;
-                offset += amount;
-            }
-            return true;
-        }
-        case kCowCompressGz: {
-            auto input = std::make_unique<Bytef[]>(op.data_length);
-            if (!GetRawBytes(offset, input.get(), op.data_length)) {
-                return false;
-            }
-
-            z_stream z = {};
-            z.next_in = input.get();
-            z.avail_in = op.data_length;
-            if (int rv = inflateInit(&z); rv != Z_OK) {
-                LOG(ERROR) << "inflateInit returned error code " << rv;
-                return false;
-            }
-
-            while (z.total_out < header_.block_size) {
-                // If no more output buffer, grab a new buffer.
-                if (z.avail_out == 0) {
-                    size_t amount = header_.block_size - z.total_out;
-                    z.next_out = reinterpret_cast<Bytef*>(sink->GetBuffer(amount, &amount));
-                    if (!z.next_out) {
-                        LOG(ERROR) << "Could not acquire buffer from sink";
-                        return false;
-                    }
-                    z.avail_out = amount;
-                }
-
-                // Remember the position of the output buffer so we can call ReturnData.
-                auto buffer = z.next_out;
-                auto avail_out = z.avail_out;
-
-                // Decompress.
-                int rv = inflate(&z, Z_NO_FLUSH);
-                if (rv != Z_OK && rv != Z_STREAM_END) {
-                    LOG(ERROR) << "inflate returned error code " << rv;
-                    return false;
-                }
-
-                // Return the section of the buffer that was updated.
-                if (z.avail_out < avail_out && !sink->ReturnData(buffer, avail_out - z.avail_out)) {
-                    LOG(ERROR) << "Could not return buffer to sink";
-                    return false;
-                }
-
-                if (rv == Z_STREAM_END) {
-                    // Error if the stream has ended, but we didn't fill the entire block.
-                    if (z.total_out != header_.block_size) {
-                        LOG(ERROR) << "Reached gz stream end but did not read a full block of data";
-                        return false;
-                    }
-                    break;
-                }
-
-                CHECK(rv == Z_OK);
-
-                // Error if the stream is expecting more data, but we don't have any to read.
-                if (z.avail_in == 0) {
-                    LOG(ERROR) << "Gz stream ended prematurely";
-                    return false;
-                }
-            }
-            return true;
-        }
+        case kCowCompressNone:
+            decompressor = IDecompressor::Uncompressed();
+            break;
+        case kCowCompressGz:
+            decompressor = IDecompressor::Gz();
+            break;
         default:
             LOG(ERROR) << "Unknown compression type: " << op.compression;
             return false;
     }
+
+    CowDataStream stream(this, op.source, op.data_length);  // Raw bytes, before decoding.
+    decompressor->set_stream(&stream);  // |stream| outlives the Decompress() call below.
+    decompressor->set_sink(sink);
+    return decompressor->Decompress(header_.block_size);  // Expect one block of output.
 }
 
 }  // namespace snapshot
diff --git a/fs_mgr/libsnapshot/include/libsnapshot/cow_reader.h b/fs_mgr/libsnapshot/include/libsnapshot/cow_reader.h
index 9e9f9b8..3998776 100644
--- a/fs_mgr/libsnapshot/include/libsnapshot/cow_reader.h
+++ b/fs_mgr/libsnapshot/include/libsnapshot/cow_reader.h
@@ -61,9 +61,6 @@
     // Return an iterator for retrieving CowOperation entries.
     virtual std::unique_ptr<ICowOpIter> GetOpIter() = 0;
 
-    // Get raw bytes from the data section.
-    virtual bool GetRawBytes(uint64_t offset, void* buffer, size_t len) = 0;
-
     // Get decoded bytes from the data section, handling any decompression.
     // All retrieved data is passed to the sink.
     virtual bool ReadData(const CowOperation& op, IByteSink* sink) = 0;
@@ -97,9 +94,10 @@
     // CowOperation objects. Get() returns a unique CowOperation object
     // whose lifeteime depends on the CowOpIter object
     std::unique_ptr<ICowOpIter> GetOpIter() override;
-    bool GetRawBytes(uint64_t offset, void* buffer, size_t len) override;
     bool ReadData(const CowOperation& op, IByteSink* sink) override;
 
+    bool GetRawBytes(uint64_t offset, void* buffer, size_t len, size_t* read);
+
   private:
     android::base::unique_fd owned_fd_;
     android::base::borrowed_fd fd_;