Merge "improve performance of blockimgdiff" into nyc-dev
diff --git a/tools/releasetools/blockimgdiff.py b/tools/releasetools/blockimgdiff.py
index 1d338ee..eee7e8d 100644
--- a/tools/releasetools/blockimgdiff.py
+++ b/tools/releasetools/blockimgdiff.py
@@ -16,7 +16,9 @@
 
 from collections import deque, OrderedDict
 from hashlib import sha1
+import array
 import common
+import functools
 import heapq
 import itertools
 import multiprocessing
@@ -24,6 +26,7 @@
 import re
 import subprocess
 import threading
+import time
 import tempfile
 
 from rangelib import RangeSet
@@ -204,6 +207,40 @@
             " to " + str(self.tgt_ranges) + ">")
 
 
+@functools.total_ordering
+class HeapItem(object):
+  """Entry for a heapq min-heap that pops the highest-scoring item first.
+
+  The item's score is captured when the entry is created.  When the score
+  changes, callers clear() the old entry and push a new one; cleared
+  entries are falsy and are skipped when they reach the top of the heap.
+  """
+
+  def __init__(self, item):
+    self.item = item
+    # Negate the score since python's heap is a min-heap and we want
+    # the maximum score.
+    self.score = -item.score
+
+  def clear(self):
+    """Invalidates this entry in place; it stays in the heap."""
+    self.item = None
+
+  def __bool__(self):
+    # An entry is live (truthy) only while it still wraps an item.
+    return self.item is not None
+
+  # Python 2 uses __nonzero__, while Python 3 uses __bool__.
+  __nonzero__ = __bool__
+
+  # The remaining comparisons are generated by functools.total_ordering.
+  def __eq__(self, other):
+    return self.score == other.score
+
+  def __le__(self, other):
+    return self.score <= other.score
+
+
 # BlockImageDiff works on two image objects.  An image object is
 # anything that provides the following attributes:
 #
@@ -734,7 +754,9 @@
     # - we write every block we care about exactly once.
 
     # Start with no blocks having been touched yet.
-    touched = RangeSet()
+    # One byte per target block: 0 = not written yet, 1 = written.  The
+    # initializer must be bytes: Python 3 rejects a str for typecode "B".
+    touched = array.array("B", b"\0" * self.tgt.total_blocks)
 
     # Imagine processing the transfers in order.
     for xf in self.transfers:
@@ -745,14 +765,28 @@
         for _, sr in xf.use_stash:
           x = x.subtract(sr)
 
-      assert not touched.overlaps(x)
-      # Check that the output blocks for this transfer haven't yet been touched.
-      assert not touched.overlaps(xf.tgt_ranges)
-      # Touch all the blocks written by this transfer.
-      touched = touched.union(xf.tgt_ranges)
+      # No block left in x (after subtracting any stash ranges above) may
+      # have been written by an earlier transfer.
+      for s, e in x:
+        for i in range(s, e):
+          assert touched[i] == 0
+
+      # Check that the output blocks for this transfer haven't yet
+      # been touched, and touch all the blocks written by this
+      # transfer.
+      for s, e in xf.tgt_ranges:
+        for i in range(s, e):
+          assert touched[i] == 0
+          touched[i] = 1
 
     # Check that we've written every target block.
-    assert touched == self.tgt.care_map
+    for s, e in self.tgt.care_map:
+      for i in range(s, e):
+        assert touched[i] == 1
+
+    # ...and no block outside the care map: every entry of touched is 0 or
+    # 1, so the sum is the number of blocks written.
+    assert sum(touched) == self.tgt.care_map.size()
 
   def ImproveVertexSequence(self):
     print("Improving vertex order...")
@@ -889,6 +917,9 @@
     for xf in self.transfers:
       xf.incoming = xf.goes_after.copy()
       xf.outgoing = xf.goes_before.copy()
+      # Net source blocks saved by placing xf as a source rather than as a
+      # sink; used to choose the "best" vertex in the loop below.
+      xf.score = sum(xf.outgoing.values()) - sum(xf.incoming.values())
 
     # We use an OrderedDict instead of just a set so that the output
     # is repeatable; otherwise it would depend on the hash values of
@@ -899,52 +930,82 @@
     s1 = deque()  # the left side of the sequence, built from left to right
     s2 = deque()  # the right side of the sequence, built from right to left
 
-    while G:
+    heap = []
+    for xf in self.transfers:
+      xf.heap_item = HeapItem(xf)
+      heap.append(xf.heap_item)
+    heapq.heapify(heap)
 
+    # Track pending sinks and sources in OrderedDicts used as ordered sets
+    # (the values are always None).  Iterating a plain set would depend on
+    # the hash values of the transfer objects and make the output
+    # non-repeatable.
+    sinks = OrderedDict.fromkeys(u for u in G if not u.outgoing)
+    sources = OrderedDict.fromkeys(u for u in G if not u.incoming)
+
+    def adjust_score(iu, delta):
+      # heapq can't update an entry in place, so invalidate iu's current
+      # entry and push a new one carrying the updated score.
+      iu.score += delta
+      iu.heap_item.clear()
+      iu.heap_item = HeapItem(iu)
+      heapq.heappush(heap, iu.heap_item)
+
+    while G:
       # Put all sinks at the end of the sequence.
-      while True:
-        sinks = [u for u in G if not u.outgoing]
-        if not sinks:
-          break
+      while sinks:
+        new_sinks = OrderedDict()
         for u in sinks:
+          if u not in G:
+            continue
           s2.appendleft(u)
           del G[u]
           for iu in u.incoming:
-            del iu.outgoing[u]
+            adjust_score(iu, -iu.outgoing.pop(u))
+            if not iu.outgoing:
+              new_sinks[iu] = None
+        sinks = new_sinks
 
       # Put all the sources at the beginning of the sequence.
-      while True:
-        sources = [u for u in G if not u.incoming]
-        if not sources:
-          break
+      while sources:
+        new_sources = OrderedDict()
         for u in sources:
+          if u not in G:
+            continue
           s1.append(u)
           del G[u]
           for iu in u.outgoing:
-            del iu.incoming[u]
+            adjust_score(iu, +iu.incoming.pop(u))
+            if not iu.incoming:
+              new_sources[iu] = None
+        sources = new_sources
 
       if not G:
         break
 
       # Find the "best" vertex to put next.  "Best" is the one that
       # maximizes the net difference in source blocks saved we get by
       # pretending it's a source rather than a sink.
 
-      max_d = None
-      best_u = None
-      for u in G:
-        d = sum(u.outgoing.values()) - sum(u.incoming.values())
-        if best_u is None or d > max_d:
-          max_d = d
-          best_u = u
+      # Pop until an entry still refers to a vertex in G: entries
+      # invalidated by adjust_score() hold None, and vertices that were
+      # already placed have been deleted from G.
+      while True:
+        u = heapq.heappop(heap).item
+        if u is not None and u in G:
+          break
 
-      u = best_u
       s1.append(u)
       del G[u]
       for iu in u.outgoing:
-        del iu.incoming[u]
+        adjust_score(iu, +iu.incoming.pop(u))
+        if not iu.incoming:
+          sources[iu] = None
+
       for iu in u.incoming:
-        del iu.outgoing[u]
+        adjust_score(iu, -iu.outgoing.pop(u))
+        if not iu.outgoing:
+          sinks[iu] = None
 
     # Now record the sequence in the 'order' field of each transfer,
     # and by rearranging self.transfers to be in the chosen sequence.
@@ -960,10 +1004,46 @@
 
   def GenerateDigraph(self):
     print("Generating digraph...")
+
+    # Each item of source_ranges will be:
+    #   - None, if that block is not used as a source,
+    #   - a transfer, if one transfer uses it as a source, or
+    #   - a list of transfers (in self.transfers order), if several do.
+    source_ranges = []
+    for b in self.transfers:
+      for s, e in b.src_ranges:
+        if e > len(source_ranges):
+          source_ranges.extend([None] * (e-len(source_ranges)))
+        for i in range(s, e):
+          if source_ranges[i] is None:
+            source_ranges[i] = b
+          elif isinstance(source_ranges[i], list):
+            source_ranges[i].append(b)
+          else:
+            source_ranges[i] = [source_ranges[i], b]
+
     for a in self.transfers:
-      for b in self.transfers:
-        if a is b:
-          continue
+      # The transfers reading any block that a writes.  An OrderedDict
+      # (values unused) serves as an ordered set: iterating a plain set
+      # would depend on the hash values of the transfer objects and could
+      # make the output non-repeatable.
+      intersections = OrderedDict()
+      for s, e in a.tgt_ranges:
+        for i in range(s, e):
+          if i >= len(source_ranges):
+            break
+          readers = source_ranges[i]
+          if readers is None:
+            continue
+          if isinstance(readers, list):
+            for b in readers:
+              intersections[b] = None
+          else:
+            intersections[readers] = None
+
+      for b in intersections:
+        if a is b:
+          continue
 
         # If the blocks written by A are read by B, then B needs to go before A.
         i = a.tgt_ranges.intersect(b.src_ranges)
@@ -1092,6 +1164,7 @@
     """Assert that all the RangeSets in 'seq' form a partition of the
     'total' RangeSet (ie, they are nonintersecting and their union
     equals 'total')."""
+
     so_far = RangeSet()
     for i in seq:
       assert not so_far.overlaps(i)