Merge "Return bytes written in CowWriter"
diff --git a/fs_mgr/libsnapshot/cow_api_test.cpp b/fs_mgr/libsnapshot/cow_api_test.cpp
index 3b3fc47..d98fe59 100644
--- a/fs_mgr/libsnapshot/cow_api_test.cpp
+++ b/fs_mgr/libsnapshot/cow_api_test.cpp
@@ -12,11 +12,15 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
+#include <sys/stat.h>
+
+#include <cstdio>
 #include <iostream>
 #include <memory>
 #include <string_view>
 
 #include <android-base/file.h>
+#include <android-base/logging.h>
 #include <gtest/gtest.h>
 #include <libsnapshot/cow_reader.h>
 #include <libsnapshot/cow_writer.h>
@@ -235,6 +239,34 @@
     ASSERT_EQ(sink.stream(), data);
 }
 
+// GetCowSize() must agree before/after Finalize() and with the on-disk size.
+TEST_F(CowTest, GetSize) {
+    CowOptions options;
+    CowWriter writer(options);
+    if (ftruncate(cow_->fd, 0) < 0) {
+        perror("Fails to set temp file size");
+        FAIL();
+    }
+    ASSERT_TRUE(writer.Initialize(cow_->fd));
+
+    std::string data = "This is some data, believe it";
+    data.resize(options.block_size, '\0');
+    ASSERT_TRUE(writer.AddCopy(10, 20));
+    ASSERT_TRUE(writer.AddRawBlocks(50, data.data(), data.size()));
+    ASSERT_TRUE(writer.AddZeroBlocks(51, 2));
+    auto size_before = writer.GetCowSize();
+    ASSERT_TRUE(writer.Finalize());
+    ASSERT_EQ(size_before, writer.GetCowSize());
+
+    struct stat buf;
+    if (fstat(cow_->fd, &buf) < 0) {
+        perror("Fails to determine size of cow image written");
+        FAIL();
+    }
+    // st_size is a signed off_t; cast it so gtest compares like-signed values.
+    ASSERT_EQ(static_cast<size_t>(buf.st_size), writer.GetCowSize());
+}
+
 }  // namespace snapshot
 }  // namespace android
 
diff --git a/fs_mgr/libsnapshot/cow_writer.cpp b/fs_mgr/libsnapshot/cow_writer.cpp
index ea8e534..ff43997 100644
--- a/fs_mgr/libsnapshot/cow_writer.cpp
+++ b/fs_mgr/libsnapshot/cow_writer.cpp
@@ -21,6 +21,7 @@
 
 #include <android-base/file.h>
 #include <android-base/logging.h>
+#include <android-base/unique_fd.h>
 #include <libsnapshot/cow_writer.h>
 #include <openssl/sha.h>
 #include <zlib.h>
@@ -70,7 +71,7 @@
 
     // Headers are not complete, but this ensures the file is at the right
     // position.
-    if (!android::base::WriteFully(fd_, &header_, sizeof(header_))) {
+    if (!WriteFully(fd_, &header_, sizeof(header_))) {
         PLOG(ERROR) << "write failed";
         return false;
     }
@@ -120,7 +121,7 @@
                 LOG(ERROR) << "Compressed block is too large: " << data.size() << " bytes";
                 return false;
             }
-            if (!android::base::WriteFully(fd_, data.data(), data.size())) {
+            if (!WriteFully(fd_, data.data(), data.size())) {
                 PLOG(ERROR) << "AddRawBlocks: write failed";
                 return false;
             }
@@ -136,7 +137,7 @@
         iter += header_.block_size;
     }
 
-    if (!compression_ && !android::base::WriteFully(fd_, data, size)) {
+    if (!compression_ && !WriteFully(fd_, data, size)) {
         PLOG(ERROR) << "AddRawBlocks: write failed";
         return false;
     }
@@ -186,6 +187,10 @@
 }
 
 bool CowWriter::Finalize() {
+    // If both fields are set then Finalize is already called.
+    if (header_.ops_offset > 0 && header_.ops_size > 0) {
+        return true;
+    }
     auto offs = lseek(fd_.get(), 0, SEEK_CUR);
     if (offs < 0) {
         PLOG(ERROR) << "lseek failed";
@@ -197,10 +202,12 @@
     SHA256(ops_.data(), ops_.size(), header_.ops_checksum);
     SHA256(&header_, sizeof(header_), header_.header_checksum);
 
-    if (lseek(fd_.get(), 0, SEEK_SET) < 0) {
-        PLOG(ERROR) << "lseek start failed";
+    if (lseek(fd_.get(), 0, SEEK_SET)) {
+        PLOG(ERROR) << "lseek failed";
         return false;
     }
+    // Header is already written, calling WriteFully will increment
+    // bytes_written_. So use android::base::WriteFully() here.
     if (!android::base::WriteFully(fd_, &header_, sizeof(header_))) {
         PLOG(ERROR) << "write header failed";
         return false;
@@ -209,13 +216,20 @@
         PLOG(ERROR) << "lseek ops failed";
         return false;
     }
-    if (!android::base::WriteFully(fd_, ops_.data(), ops_.size())) {
+    if (!WriteFully(fd_, ops_.data(), ops_.size())) {
         PLOG(ERROR) << "write ops failed";
         return false;
     }
+
+    // clear ops_ so that subsequent calls to GetSize() still works.
+    ops_.clear();
     return true;
 }
 
+size_t CowWriter::GetCowSize() {  // Bytes already written plus ops still buffered in ops_.
+    return bytes_written_.load() + sizeof(ops_[0]) * ops_.size();
+}
+
 bool CowWriter::GetDataPos(uint64_t* pos) {
     off_t offs = lseek(fd_.get(), 0, SEEK_CUR);
     if (offs < 0) {
@@ -226,5 +240,15 @@
     return true;
 }
 
+// Write |size| bytes to |fd|. bytes_written_ is only advanced after the write
+// succeeds, so a failed write is never counted toward GetCowSize().
+bool CowWriter::WriteFully(base::borrowed_fd fd, const void* data, size_t size) {
+    if (!android::base::WriteFully(fd, data, size)) {
+        return false;
+    }
+    bytes_written_ += size;
+    return true;
+}
+
 }  // namespace snapshot
 }  // namespace android
diff --git a/fs_mgr/libsnapshot/include/libsnapshot/cow_writer.h b/fs_mgr/libsnapshot/include/libsnapshot/cow_writer.h
index 5a2cbd6..8826b7a 100644
--- a/fs_mgr/libsnapshot/include/libsnapshot/cow_writer.h
+++ b/fs_mgr/libsnapshot/include/libsnapshot/cow_writer.h
@@ -47,6 +47,14 @@
     // Encode a sequence of zeroed blocks. |size| must be a multiple of the block size.
     virtual bool AddZeroBlocks(uint64_t new_block_start, uint64_t num_blocks) = 0;
 
+    // Finalize all COW operations and flush pending writes.
+    // Return true if successful.
+    virtual bool Finalize() = 0;
+
+    // Return the number of bytes the cow image will occupy once Finalize()
+    // has been called. This may be queried either before or after Finalize().
+    virtual size_t GetCowSize() = 0;
+
   protected:
     CowOptions options_;
 };
@@ -63,23 +71,26 @@
     bool AddRawBlocks(uint64_t new_block_start, const void* data, size_t size) override;
     bool AddZeroBlocks(uint64_t new_block_start, uint64_t num_blocks) override;
 
-    // Finalize all COW operations and flush pending writes.
-    bool Finalize();
+    bool Finalize() override;
+
+    size_t GetCowSize() override;
 
   private:
     void SetupHeaders();
     bool GetDataPos(uint64_t* pos);
+    bool WriteFully(base::borrowed_fd fd, const void* data, size_t size);
     std::basic_string<uint8_t> Compress(const void* data, size_t length);
 
   private:
     android::base::unique_fd owned_fd_;
     android::base::borrowed_fd fd_;
-    CowHeader header_;
+    CowHeader header_{};
     int compression_ = 0;
 
     // :TODO: this is not efficient, but stringstream ubsan aborts because some
     // bytes overflow a signed char.
     std::basic_string<uint8_t> ops_;
+    std::atomic<size_t> bytes_written_ = 0;
 };
 
 }  // namespace snapshot