Fix merge sequence failure due to XOR conversion am: f79d4f88e1

Original change: https://android-review.googlesource.com/c/platform/system/update_engine/+/2673358

Change-Id: Iba90964e82f12cc9eb14d7fa8d6ab5ef07242135
Signed-off-by: Automerger Merge Worker <android-build-automerger-merge-worker@system.gserviceaccount.com>
diff --git a/payload_consumer/xor_extent_writer.cc b/payload_consumer/xor_extent_writer.cc
index 4534c05..fe7eca7 100644
--- a/payload_consumer/xor_extent_writer.cc
+++ b/payload_consumer/xor_extent_writer.cc
@@ -20,15 +20,131 @@
 
 #include "update_engine/common/utils.h"
 #include "update_engine/payload_consumer/xor_extent_writer.h"
+#include "update_engine/payload_generator/extent_ranges.h"
 #include "update_engine/payload_generator/extent_utils.h"
+#include "update_engine/update_metadata.pb.h"
 
 namespace chromeos_update_engine {
+// Reads |xor_ext.num_blocks()| blocks from |source_fd_| starting at byte
+// offset |src_offset|, XORs them with the target data in |bytes|, and emits
+// the result as a COW XOR operation for |xor_ext|. |size| must equal the byte
+// length of |xor_ext|. Returns false on a size mismatch, a failed or short
+// read, or a COW writer failure.
+bool XORExtentWriter::WriteXorCowOp(const uint8_t* bytes,
+                                    const size_t size,
+                                    const Extent& xor_ext,
+                                    const size_t src_offset) {
+  // |bytes| is consumed for the whole extent, so reject a mismatched |size|
+  // rather than reading past the caller's buffer.
+  const auto xor_size = BlockSize() * xor_ext.num_blocks();
+  if (size != xor_size) {
+    LOG(ERROR) << "XOR data size " << size << " does not match " << xor_ext
+               << " (" << xor_size << " bytes)";
+    return false;
+  }
+  xor_block_data.resize(xor_size);
+  const auto src_block = src_offset / BlockSize();
+  ssize_t bytes_read = 0;
+  TEST_AND_RETURN_FALSE_ERRNO(utils::PReadAll(source_fd_,
+                                              xor_block_data.data(),
+                                              xor_block_data.size(),
+                                              src_offset,
+                                              &bytes_read));
+  if (bytes_read != static_cast<ssize_t>(xor_block_data.size())) {
+    LOG(ERROR) << "bytes_read: " << bytes_read << ", expected to read "
+               << xor_block_data.size() << " at block " << src_block
+               << " offset " << src_offset % BlockSize();
+    return false;
+  }
+
+  // XOR the source data in place with the target data.
+  std::transform(xor_block_data.cbegin(),
+                 xor_block_data.cend(),
+                 bytes,
+                 xor_block_data.begin(),
+                 std::bit_xor<unsigned char>{});
+  // The source position is passed as (block, offset within that block).
+  TEST_AND_RETURN_FALSE(cow_writer_->AddXorBlocks(xor_ext.start_block(),
+                                                  xor_block_data.data(),
+                                                  xor_block_data.size(),
+                                                  src_block,
+                                                  src_offset % BlockSize()));
+  return true;
+}
+
+// Emits the XOR op(s) for |xor_ext|, the part of the current write covered
+// by |merge_op|; |bytes| holds the target data for |xor_ext| (|size| bytes).
+//
+// With a non-zero src_offset the source read extends src_offset bytes into
+// the block after the mapped source range. If that read would pass
+// |partition_size_|, the extent is split: all blocks but the last keep the
+// merge op's offset, and the last block is XORed against its own source block
+// at offset 0, so the op stays XOR instead of becoming a raw REPLACE.
+// Overflowing the partition by a full block or more is an error.
+bool XORExtentWriter::WriteXorExtent(const uint8_t* bytes,
+                                     const size_t size,
+                                     const Extent& xor_ext,
+                                     const CowMergeOperation* merge_op) {
+  // Source block that maps to the first block of |xor_ext|.
+  const auto src_block = merge_op->src_extent().start_block() +
+                         xor_ext.start_block() -
+                         merge_op->dst_extent().start_block();
+  const auto read_end_offset =
+      (src_block + xor_ext.num_blocks()) * BlockSize() + merge_op->src_offset();
+  // A partition_size_ of 0 disables the bound check.
+  const auto is_out_of_bound_read =
+      read_end_offset > partition_size_ && partition_size_ != 0;
+  const auto oob_bytes =
+      is_out_of_bound_read ? read_end_offset - partition_size_ : 0;
+  if (is_out_of_bound_read) {
+    if (oob_bytes >= BlockSize()) {
+      // Log the merge op's fields; streaming the |merge_op| pointer itself
+      // would only print its address.
+      LOG(ERROR) << "XOR op overflowed source partition by more than "
+                 << BlockSize() << ", " << xor_ext
+                 << ", merge op src: " << merge_op->src_extent()
+                 << ", dst: " << merge_op->dst_extent()
+                 << ", src offset: " << merge_op->src_offset()
+                 << ", out of bound bytes: " << oob_bytes
+                 << ", partition size: " << partition_size_;
+      return false;
+    }
+    if (oob_bytes > merge_op->src_offset()) {
+      // NOTE(review): only logged. Here even the offset-0 read of the last
+      // block below ends past |partition_size_|; confirm whether this should
+      // return false.
+      LOG(ERROR) << "XOR op overflowed source offset, out of bound bytes: "
+                 << oob_bytes << ", source offset: " << merge_op->src_offset();
+    }
+    Extent non_oob_extent =
+        ExtentForRange(xor_ext.start_block(), xor_ext.num_blocks() - 1);
+    if (non_oob_extent.num_blocks() > 0) {
+      TEST_AND_RETURN_FALSE(
+          WriteXorCowOp(bytes,
+                        BlockSize() * non_oob_extent.num_blocks(),
+                        non_oob_extent,
+                        src_block * BlockSize() + merge_op->src_offset()));
+    }
+    // Last block: XOR against its source block at offset 0, dropping the
+    // out-of-bound tail of the original read.
+    const Extent last_block =
+        ExtentForRange(xor_ext.start_block() + xor_ext.num_blocks() - 1, 1);
+    TEST_AND_RETURN_FALSE(
+        WriteXorCowOp(bytes + (xor_ext.num_blocks() - 1) * BlockSize(),
+                      BlockSize(),
+                      last_block,
+                      (src_block + xor_ext.num_blocks() - 1) * BlockSize()));
+    return true;
+  }
+  TEST_AND_RETURN_FALSE(WriteXorCowOp(
+      bytes, size, xor_ext, src_block * BlockSize() + merge_op->src_offset()));
+  return true;
+}
 
 // Returns true on success.
 bool XORExtentWriter::WriteExtent(const void* bytes,
                                   const Extent& extent,
                                   const size_t size) {
-  brillo::Blob xor_block_data;
   const auto xor_extents = xor_map_.GetIntersectingExtents(extent);
   for (const auto& xor_ext : xor_extents) {
     const auto merge_op_opt = xor_map_.Get(xor_ext);
@@ -60,52 +176,16 @@
                  << xor_ext << " xor_map extent: " << merge_op->dst_extent();
       return false;
     }
-    const auto src_offset = merge_op->src_offset();
-    const auto src_block = merge_op->src_extent().start_block() +
-                           xor_ext.start_block() -
-                           merge_op->dst_extent().start_block();
     const auto i = xor_ext.start_block() - extent.start_block();
     const auto dst_block_data =
         static_cast<const unsigned char*>(bytes) + i * BlockSize();
-    const auto is_out_of_bound_read =
-        (src_block + xor_ext.num_blocks()) * BlockSize() + src_offset >
-            partition_size_ &&
-        partition_size_ != 0;
-    if (is_out_of_bound_read) {
-      LOG(INFO) << "Getting partial read for last block, converting "
-                   "XOR operation to a regular replace "
-                << xor_ext;
-      TEST_AND_RETURN_FALSE(
-          cow_writer_->AddRawBlocks(xor_ext.start_block(),
-                                    dst_block_data,
-                                    xor_ext.num_blocks() * BlockSize()));
-      continue;
-    }
-    xor_block_data.resize(BlockSize() * xor_ext.num_blocks());
-    ssize_t bytes_read = 0;
-    TEST_AND_RETURN_FALSE_ERRNO(
-        utils::PReadAll(source_fd_,
-                        xor_block_data.data(),
-                        xor_block_data.size(),
-                        src_offset + src_block * BlockSize(),
-                        &bytes_read));
-    if (bytes_read != static_cast<ssize_t>(xor_block_data.size())) {
-      LOG(ERROR) << "bytes_read: " << bytes_read << ", expected to read "
-                 << xor_block_data.size() << " at block " << src_block
-                 << " offset " << src_offset;
+    if (!WriteXorExtent(dst_block_data,
+                        xor_ext.num_blocks() * BlockSize(),
+                        xor_ext,
+                        merge_op)) {
+      LOG(ERROR) << "Failed to write XOR extent " << xor_ext;
       return false;
     }
-
-    std::transform(xor_block_data.cbegin(),
-                   xor_block_data.cbegin() + xor_block_data.size(),
-                   dst_block_data,
-                   xor_block_data.begin(),
-                   std::bit_xor<unsigned char>{});
-    TEST_AND_RETURN_FALSE(cow_writer_->AddXorBlocks(xor_ext.start_block(),
-                                                    xor_block_data.data(),
-                                                    xor_block_data.size(),
-                                                    src_block,
-                                                    src_offset));
   }
   const auto replace_extents = xor_map_.GetNonIntersectingExtents(extent);
   return WriteReplaceExtents(replace_extents, extent, bytes, size);
diff --git a/payload_consumer/xor_extent_writer.h b/payload_consumer/xor_extent_writer.h
index 57c99c2..2074ee2 100644
--- a/payload_consumer/xor_extent_writer.h
+++ b/payload_consumer/xor_extent_writer.h
@@ -56,10 +56,26 @@
                            const Extent& extent,
                            const void* bytes,
                            size_t size);
+  // Writes the COW XOR op(s) for |xor_ext| using |merge_op|'s source mapping.
+  // If the source read would run past the partition end, the last block is
+  // emitted separately with a zero source offset.
+  bool WriteXorExtent(const uint8_t* bytes,
+                      const size_t size,
+                      const Extent& xor_ext,
+                      const CowMergeOperation* merge_op);
+  // XORs the target data in |bytes| with source data read at byte offset
+  // |src_offset| and emits the result as a COW XOR op for |xor_ext|.
+  bool WriteXorCowOp(const uint8_t* bytes,
+                     const size_t size,
+                     const Extent& xor_ext,
+                     size_t src_offset);
   const google::protobuf::RepeatedPtrField<Extent>& src_extents_;
   const FileDescriptorPtr source_fd_;
   const ExtentMap<const CowMergeOperation*>& xor_map_;
   android::snapshot::ICowWriter* cow_writer_;
+  // Scratch buffer for WriteXorCowOp(), kept as a member so its allocation is
+  // reused across calls.
+  std::vector<uint8_t> xor_block_data;
   const size_t partition_size_;
 };
 
diff --git a/payload_consumer/xor_extent_writer_unittest.cc b/payload_consumer/xor_extent_writer_unittest.cc
index 45796a6..827030a 100644
--- a/payload_consumer/xor_extent_writer_unittest.cc
+++ b/payload_consumer/xor_extent_writer_unittest.cc
@@ -14,6 +14,7 @@
 // limitations under the License.
 //
 
+#include <algorithm>
 #include <memory>
 
 #include <unistd.h>
@@ -22,7 +23,7 @@
 #include <gtest/gtest.h>
 #include <libsnapshot/mock_cow_writer.h>
 
-#include "common/utils.h"
+#include "update_engine/common/utils.h"
 #include "update_engine/payload_consumer/extent_map.h"
 #include "update_engine/payload_consumer/file_descriptor.h"
 #include "update_engine/payload_consumer/xor_extent_writer.h"
@@ -44,11 +45,13 @@
     ASSERT_EQ(ftruncate64(source_part_.fd, kBlockSize * NUM_BLOCKS), 0);
     ASSERT_EQ(ftruncate64(target_part_.fd, kBlockSize * NUM_BLOCKS), 0);
 
-    // Fill source part with 1s, as we are computing XOR between source and
-    // target data later.
+    // Fill each source block with the byte pattern i & 0xFF (rather than a
+    // constant), as we are computing XOR between source and target data later.
     ASSERT_EQ(lseek(source_part_.fd, 0, SEEK_SET), 0);
     brillo::Blob buffer(kBlockSize);
-    std::fill(buffer.begin(), buffer.end(), 1);
+    for (size_t i = 0; i < kBlockSize; i++) {
+      buffer[i] = i & 0xFF;
+    }
     for (size_t i = 0; i < NUM_BLOCKS; i++) {
       ASSERT_EQ(write(source_part_.fd, buffer.data(), buffer.size()),
                 static_cast<ssize_t>(buffer.size()));
@@ -195,12 +198,13 @@
   // [12-14] => [320-322], [20-22] => [420-422], [NUM_BLOCKS-3] => [2-5]
 
   // merge op:
-  // [NUM_BLOCKS-1] => [2-3]
+  // [NUM_BLOCKS-1] => [2]
 
   // Expected result:
-  // [12-16] should be REPLACE blocks
+  // [320-322] should be REPLACE blocks
   // [420-422] should be REPLACE blocks
-  // [2-4] should be REPLACE blocks
+  // [2] should be XOR blocks, with 0 offset to avoid out of bound read
+  // [3-4] should be REPLACE blocks
 
   auto zeros = utils::GetReadonlyZeroBlock(kBlockSize * 9);
   EXPECT_CALL(cow_writer_, AddRawBlocks(320, zeros->data(), kBlockSize * 3))
@@ -209,15 +213,87 @@
               AddRawBlocks(420, zeros->data() + 3 * kBlockSize, kBlockSize * 3))
       .WillOnce(Return(true));
 
-  EXPECT_CALL(cow_writer_,
-              AddRawBlocks(2, zeros->data() + 6 * kBlockSize, kBlockSize))
+  EXPECT_CALL(cow_writer_, AddXorBlocks(2, _, kBlockSize, NUM_BLOCKS - 1, 0))
       .WillOnce(Return(true));
   EXPECT_CALL(cow_writer_,
-              AddRawBlocks(3, zeros->data() + 7 * kBlockSize, kBlockSize * 2))
+              AddRawBlocks(3, zeros->data() + kBlockSize * 7, kBlockSize * 2))
       .WillOnce(Return(true));
 
   ASSERT_TRUE(writer_.Init(op_.dst_extents(), kBlockSize));
   ASSERT_TRUE(writer_.Write(zeros->data(), zeros->size()));
 }
 
+TEST_F(XorExtentWriterTest, LastMultiBlockTest) {
+  constexpr auto COW_XOR = CowMergeOperation::COW_XOR;
+
+  const auto op3 = CreateCowMergeOperation(
+      ExtentForRange(NUM_BLOCKS - 4, 4), ExtentForRange(2, 4), COW_XOR, 777);
+  ASSERT_TRUE(xor_map_.AddExtent(op3.dst_extent(), &op3));
+
+  *op_.add_src_extents() = ExtentForRange(NUM_BLOCKS - 4, 4);
+  *op_.add_dst_extents() = ExtentForRange(2, 4);
+  XORExtentWriter writer_{
+      op_, source_fd_, &cow_writer_, xor_map_, NUM_BLOCKS * kBlockSize};
+
+  // OTA op:
+  // [NUM_BLOCKS-4] => [2-5]
+
+  // merge op:
+  // [NUM_BLOCKS-4] => [2-5]
+
+  // Expected result:
+  // [2-4] should be XOR blocks with source offset 777
+  // [5] should be XOR blocks with 0 offset to avoid out of bound read
+
+  // Send arbitrary data, just to confirm that XORExtentWriter did XOR the
+  // source data with target data
+  std::vector<uint8_t> op_data(kBlockSize * 4);
+  for (size_t i = 0; i < op_data.size(); i++) {
+    if (i % kBlockSize == 0) {
+      op_data[i] = 1;
+    } else {
+      op_data[i] = (op_data[i - 1] * 3) & 0xFF;
+    }
+  }
+  auto&& verify_xor_data = [source_fd_(source_fd_), &op_data](
+                               uint32_t new_block_start,
+                               const void* data,
+                               size_t size,
+                               uint32_t old_block,
+                               uint16_t offset) -> bool {
+    std::vector<uint8_t> source_data(size);
+    ssize_t bytes_read{};
+    TEST_AND_RETURN_FALSE_ERRNO(utils::PReadAll(source_fd_,
+                                                source_data.data(),
+                                                source_data.size(),
+                                                old_block * kBlockSize + offset,
+                                                &bytes_read));
+    TEST_EQ(bytes_read, static_cast<ssize_t>(source_data.size()));
+    std::transform(source_data.begin(),
+                   source_data.end(),
+                   static_cast<const uint8_t*>(data),
+                   source_data.begin(),
+                   std::bit_xor<uint8_t>{});
+    // XORing the received data with the source again must give back the
+    // target data of the blocks being written; the OTA op's destination
+    // starts at block 2.
+    const uint8_t* expected =
+        op_data.data() + (new_block_start - 2) * kBlockSize;
+    if (memcmp(source_data.data(), expected, source_data.size()) != 0) {
+      LOG(ERROR) << "XOR data received does not appear to be an XOR between "
+                    "source and target data";
+      return false;
+    }
+    return true;
+  };
+  EXPECT_CALL(cow_writer_,
+              AddXorBlocks(2, _, kBlockSize * 3, NUM_BLOCKS - 4, 777))
+      .WillOnce(verify_xor_data);
+  EXPECT_CALL(cow_writer_, AddXorBlocks(5, _, kBlockSize, NUM_BLOCKS - 1, 0))
+      .WillOnce(verify_xor_data);
+
+  ASSERT_TRUE(writer_.Init(op_.dst_extents(), kBlockSize));
+  ASSERT_TRUE(writer_.Write(op_data.data(), op_data.size()));
+}
+
 }  // namespace chromeos_update_engine