Implement the topological sort in CreateCowMergeOperation am: 31ad11922c

Original change: https://android-review.googlesource.com/c/platform/system/update_engine/+/1396952

Change-Id: Id5e2370fb3458352e572c55b5c65e0b51be98158
diff --git a/payload_generator/merge_sequence_generator.cc b/payload_generator/merge_sequence_generator.cc
index dd801d6..eaffeac 100644
--- a/payload_generator/merge_sequence_generator.cc
+++ b/payload_generator/merge_sequence_generator.cc
@@ -16,6 +16,8 @@
 
 #include "update_engine/payload_generator/merge_sequence_generator.h"
 
+#include <algorithm>
+
 #include "update_engine/payload_generator/extent_utils.h"
 
 namespace chromeos_update_engine {
@@ -85,6 +87,7 @@
     }
   }
 
+  std::sort(sequence.begin(), sequence.end());
   return std::unique_ptr<MergeSequenceGenerator>(
       new MergeSequenceGenerator(sequence));
 }
@@ -92,11 +95,154 @@
 bool MergeSequenceGenerator::FindDependency(
     std::map<CowMergeOperation, std::set<CowMergeOperation>>* result) const {
   CHECK(result);
+  LOG(INFO) << "Finding dependencies";
+
+  // Since the OTA operation may reuse some source blocks, binary search the
+  // dst extents (operations_ is assumed sorted by them) to find overlaps.
+  std::map<CowMergeOperation, std::set<CowMergeOperation>> merge_after;
+  for (const auto& op : operations_) {
+    // lower bound (inclusive): first op whose dst extent's end block >= this
+    // op's src extent's start block.
+    const auto lower_it = std::lower_bound(
+        operations_.begin(),
+        operations_.end(),
+        op,
+        [](const CowMergeOperation& it, const CowMergeOperation& op) {
+          auto dst_end_block =
+              it.dst_extent().start_block() + it.dst_extent().num_blocks() - 1;
+          return dst_end_block < op.src_extent().start_block();
+        });
+    // upper bound (exclusive): first op whose dst start block > src end block
+    const auto upper_it = std::upper_bound(
+        lower_it,
+        operations_.end(),
+        op,
+        [](const CowMergeOperation& op, const CowMergeOperation& it) {
+          auto src_end_block =
+              op.src_extent().start_block() + op.src_extent().num_blocks() - 1;
+          return src_end_block < it.dst_extent().start_block();
+        });
+
+    // TODO(xunchang) skip inserting the empty set to merge_after.
+    if (lower_it == upper_it) {
+      merge_after.insert({op, {}});
+    } else {
+      std::set<CowMergeOperation> operations(lower_it, upper_it);
+      auto it = operations.find(op);
+      if (it != operations.end()) {
+        LOG(INFO) << "Self overlapping " << op;
+        operations.erase(it);
+      }
+      auto ret = merge_after.emplace(op, std::move(operations));
+      // emplace() reports false if op was already a key; that must not happen.
+      CHECK(ret.second);
+    }
+  }
+
+  *result = std::move(merge_after);
   return true;
 }
 
 bool MergeSequenceGenerator::Generate(
     std::vector<CowMergeOperation>* sequence) const {
+  sequence->clear();
+  std::map<CowMergeOperation, std::set<CowMergeOperation>> merge_after;
+  if (!FindDependency(&merge_after)) {
+    LOG(ERROR) << "Failed to find dependencies";
+    return false;
+  }
+
+  LOG(INFO) << "Generating sequence";
+
+  // Use the non-DFS (in-degree counting) version of the topological sort, so
+  // we can control the operations to discard to break cycles; thus yielding a
+  // deterministic sequence.
+  std::map<CowMergeOperation, int> incoming_edges;
+  for (const auto& it : merge_after) {
+    for (const auto& blocked : it.second) {
+      // operator[] value-initializes a missing count to 0.
+      incoming_edges[blocked] += 1;
+    }
+  }
+
+  std::set<CowMergeOperation> free_operations;
+  for (const auto& op : operations_) {
+    if (incoming_edges.find(op) == incoming_edges.end()) {
+      free_operations.insert(op);
+    }
+  }
+
+  std::vector<CowMergeOperation> merge_sequence;
+  std::set<CowMergeOperation> convert_to_raw;
+  while (!incoming_edges.empty()) {
+    if (!free_operations.empty()) {
+      merge_sequence.insert(
+          merge_sequence.end(), free_operations.begin(), free_operations.end());
+    } else {
+      auto to_convert = incoming_edges.begin()->first;
+      free_operations.insert(to_convert);
+      convert_to_raw.insert(to_convert);
+      LOG(INFO) << "Converting operation to raw " << to_convert;
+    }
+
+    std::set<CowMergeOperation> next_free_operations;
+    for (const auto& op : free_operations) {
+      incoming_edges.erase(op);
+
+      // Now that this particular operation is merged, other operations blocked
+      // by this one may be free. Decrement the count of blocking operations,
+      // and set up the free operations for the next iteration.
+      for (const auto& blocked : merge_after[op]) {
+        auto it = incoming_edges.find(blocked);
+        if (it == incoming_edges.end()) {
+          continue;
+        }
+
+        auto blocking_transfer_count = &it->second;
+        if (*blocking_transfer_count <= 0) {
+          LOG(ERROR) << "Unexpected count in merge after map "
+                     << *blocking_transfer_count;
+          return false;
+        }
+        // One fewer blocker. Once the count reaches zero the operation is free
+        // and gets added to the merge sequence in the next iteration.
+        *blocking_transfer_count -= 1;
+        if (*blocking_transfer_count == 0) {
+          next_free_operations.insert(blocked);
+        }
+      }
+    }
+
+    LOG(INFO) << "Remaining transfers " << incoming_edges.size()
+              << ", free transfers " << free_operations.size()
+              << ", merge_sequence size " << merge_sequence.size();
+    free_operations = std::move(next_free_operations);
+  }
+
+  if (!free_operations.empty()) {
+    merge_sequence.insert(
+        merge_sequence.end(), free_operations.begin(), free_operations.end());
+  }
+
+  CHECK_EQ(operations_.size(), merge_sequence.size() + convert_to_raw.size());
+
+  size_t blocks_in_sequence = 0;
+  for (const CowMergeOperation& transfer : merge_sequence) {
+    blocks_in_sequence += transfer.dst_extent().num_blocks();
+  }
+
+  size_t blocks_in_raw = 0;
+  for (const CowMergeOperation& transfer : convert_to_raw) {
+    blocks_in_raw += transfer.dst_extent().num_blocks();
+  }
+
+  LOG(INFO) << "Blocks in merge sequence " << blocks_in_sequence
+            << ", blocks in raw " << blocks_in_raw;
+  if (!ValidateSequence(merge_sequence)) {
+    return false;
+  }
+
+  *sequence = std::move(merge_sequence);
   return true;
 }
 
diff --git a/payload_generator/merge_sequence_generator_unittest.cc b/payload_generator/merge_sequence_generator_unittest.cc
index 83cf78f..567ede1 100644
--- a/payload_generator/merge_sequence_generator_unittest.cc
+++ b/payload_generator/merge_sequence_generator_unittest.cc
@@ -14,20 +14,15 @@
 // limitations under the License.
 //
 
-#include <string>
+#include <algorithm>
 #include <vector>
 
 #include <gtest/gtest.h>
 
-#include "update_engine/common/test_utils.h"
 #include "update_engine/payload_consumer/payload_constants.h"
 #include "update_engine/payload_generator/extent_utils.h"
 #include "update_engine/payload_generator/merge_sequence_generator.h"
 
-using chromeos_update_engine::test_utils::FillWithData;
-using std::string;
-using std::vector;
-
 namespace chromeos_update_engine {
 class MergeSequenceGeneratorTest : public ::testing::Test {
  protected:
@@ -35,6 +30,23 @@
                        const std::vector<CowMergeOperation>& expected) {
     ASSERT_EQ(expected, generator->operations_);
   }
+
+  void FindDependency(
+      std::vector<CowMergeOperation> transfers,
+      std::map<CowMergeOperation, std::set<CowMergeOperation>>* result) {
+    std::sort(transfers.begin(), transfers.end());
+    MergeSequenceGenerator generator(std::move(transfers));
+    ASSERT_TRUE(generator.FindDependency(result));
+  }
+
+  void GenerateSequence(std::vector<CowMergeOperation> transfers,
+                        const std::vector<CowMergeOperation>& expected) {
+    std::sort(transfers.begin(), transfers.end());
+    MergeSequenceGenerator generator(std::move(transfers));
+    std::vector<CowMergeOperation> sequence;
+    ASSERT_TRUE(generator.Generate(&sequence));
+    ASSERT_EQ(expected, sequence);
+  }
 };
 
 TEST_F(MergeSequenceGeneratorTest, Create) {
@@ -78,6 +90,47 @@
   VerifyTransfers(generator.get(), expected);
 }
 
+TEST_F(MergeSequenceGeneratorTest, FindDependency) {
+  std::vector<CowMergeOperation> transfers = {
+      CreateCowMergeOperation(ExtentForRange(10, 10), ExtentForRange(15, 10)),
+      CreateCowMergeOperation(ExtentForRange(40, 10), ExtentForRange(50, 10)),
+  };
+
+  std::map<CowMergeOperation, std::set<CowMergeOperation>> merge_after;
+  FindDependency(transfers, &merge_after);
+  ASSERT_EQ(std::set<CowMergeOperation>(), merge_after.at(transfers[0]));
+  ASSERT_EQ(std::set<CowMergeOperation>(), merge_after.at(transfers[1]));
+
+  transfers = {
+      CreateCowMergeOperation(ExtentForRange(10, 10), ExtentForRange(25, 10)),
+      CreateCowMergeOperation(ExtentForRange(24, 5), ExtentForRange(35, 5)),
+      CreateCowMergeOperation(ExtentForRange(30, 10), ExtentForRange(15, 10)),
+  };
+
+  FindDependency(transfers, &merge_after);
+  ASSERT_EQ(std::set<CowMergeOperation>({transfers[2]}),
+            merge_after.at(transfers[0]));
+  ASSERT_EQ(std::set<CowMergeOperation>({transfers[0], transfers[2]}),
+            merge_after.at(transfers[1]));
+  ASSERT_EQ(std::set<CowMergeOperation>({transfers[0], transfers[1]}),
+            merge_after.at(transfers[2]));
+}
+
+TEST_F(MergeSequenceGeneratorTest, FindDependency_ReusedSourceBlocks) {
+  std::vector<CowMergeOperation> transfers = {
+      CreateCowMergeOperation(ExtentForRange(5, 10), ExtentForRange(15, 10)),
+      CreateCowMergeOperation(ExtentForRange(6, 5), ExtentForRange(30, 5)),
+      CreateCowMergeOperation(ExtentForRange(50, 5), ExtentForRange(5, 5)),
+  };
+
+  std::map<CowMergeOperation, std::set<CowMergeOperation>> merge_after;
+  FindDependency(transfers, &merge_after);
+  ASSERT_EQ(std::set<CowMergeOperation>({transfers[2]}),
+            merge_after.at(transfers[0]));
+  ASSERT_EQ(std::set<CowMergeOperation>({transfers[2]}),
+            merge_after.at(transfers[1]));
+}
+
 TEST_F(MergeSequenceGeneratorTest, ValidateSequence) {
   std::vector<CowMergeOperation> transfers = {
       CreateCowMergeOperation(ExtentForRange(10, 10), ExtentForRange(15, 10)),
@@ -94,4 +147,50 @@
   ASSERT_FALSE(MergeSequenceGenerator::ValidateSequence(transfers));
 }
 
+TEST_F(MergeSequenceGeneratorTest, GenerateSequenceNoCycles) {
+  std::vector<CowMergeOperation> transfers = {
+      CreateCowMergeOperation(ExtentForRange(10, 10), ExtentForRange(15, 10)),
+      // file3 must merge before file2: file2's dst overwrites file3's src.
+      CreateCowMergeOperation(ExtentForRange(40, 5), ExtentForRange(25, 5)),
+      CreateCowMergeOperation(ExtentForRange(25, 10), ExtentForRange(30, 10)),
+  };
+
+  std::vector<CowMergeOperation> expected{
+      transfers[0], transfers[2], transfers[1]};
+  GenerateSequence(transfers, expected);
+}
+
+TEST_F(MergeSequenceGeneratorTest, GenerateSequenceWithCycles) {
+  std::vector<CowMergeOperation> transfers = {
+      CreateCowMergeOperation(ExtentForRange(25, 10), ExtentForRange(30, 10)),
+      CreateCowMergeOperation(ExtentForRange(30, 10), ExtentForRange(40, 10)),
+      CreateCowMergeOperation(ExtentForRange(40, 10), ExtentForRange(25, 10)),
+      CreateCowMergeOperation(ExtentForRange(10, 10), ExtentForRange(15, 10)),
+  };
+
+  // Files 1, 2 and 3 form a cycle. File 3, whose dst extent has the smallest
+  // offset within the cycle, will be converted to raw blocks.
+  std::vector<CowMergeOperation> expected{
+      transfers[3], transfers[1], transfers[0]};
+  GenerateSequence(transfers, expected);
+}
+
+TEST_F(MergeSequenceGeneratorTest, GenerateSequenceMultipleCycles) {
+  std::vector<CowMergeOperation> transfers = {
+      // cycle 1
+      CreateCowMergeOperation(ExtentForRange(10, 10), ExtentForRange(25, 10)),
+      CreateCowMergeOperation(ExtentForRange(24, 5), ExtentForRange(35, 5)),
+      CreateCowMergeOperation(ExtentForRange(30, 10), ExtentForRange(15, 10)),
+      // cycle 2
+      CreateCowMergeOperation(ExtentForRange(55, 10), ExtentForRange(60, 10)),
+      CreateCowMergeOperation(ExtentForRange(60, 10), ExtentForRange(70, 10)),
+      CreateCowMergeOperation(ExtentForRange(70, 10), ExtentForRange(55, 10)),
+  };
+
+  // Files 3 and 6 will be converted to raw, so they are not in the sequence.
+  std::vector<CowMergeOperation> expected{
+      transfers[1], transfers[0], transfers[4], transfers[3]};
+  GenerateSequence(transfers, expected);
+}
+
 }  // namespace chromeos_update_engine