blockimgdiff: Factor out the diff_worker

We will call it at an earlier point to compute the patch sizes, and
to choose which transfers to convert to 'new'.
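
A rough sketch of the intended follow-up usage (illustrative only;
ConvertToNew is assumed from the existing Transfer class, and the
exact call site will come in a later change):

  # Compute the patches and the compressed target sizes up front.
  patches = self.ComputePatchesForInputList(diff_queue, True)
  for xf_index, patch_info, compressed_size in patches:
    xf = self.transfers[xf_index]
    # If the patch isn't smaller than the compressed target data,
    # convert the transfer to 'new' instead of shipping the diff.
    if patch_info and len(patch_info.content) >= compressed_size:
      xf.ConvertToNew()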

Bug: 120561199
Test: Generate an incremental update on shiner
Change-Id: I29a0c8e75c9e5b66a266c1387186692a86fcbe43
diff --git a/tools/releasetools/blockimgdiff.py b/tools/releasetools/blockimgdiff.py
index 2d20e23..b5e01d3 100644
--- a/tools/releasetools/blockimgdiff.py
+++ b/tools/releasetools/blockimgdiff.py
@@ -26,7 +26,8 @@
 import re
 import sys
 import threading
-from collections import deque, OrderedDict
+import zlib
+from collections import deque, namedtuple, OrderedDict
 from hashlib import sha1
 
 import common
@@ -36,8 +37,12 @@
 
 logger = logging.getLogger(__name__)
 
+# The tuple contains the style and bytes of a bsdiff|imgdiff patch.
+PatchInfo = namedtuple("PatchInfo", ["imgdiff", "content"])
+
 
 def compute_patch(srcfile, tgtfile, imgdiff=False):
+  """Calls bsdiff|imgdiff to compute the patch data, returns a PatchInfo."""
   patchfile = common.MakeTempFile(prefix='patch-')
 
   cmd = ['imgdiff', '-z'] if imgdiff else ['bsdiff']
@@ -52,7 +57,7 @@
     raise ValueError(output)
 
   with open(patchfile, 'rb') as f:
-    return f.read()
+    return PatchInfo(imgdiff, f.read())
 
 
 class Image(object):
@@ -203,17 +208,17 @@
     self.id = len(by_id)
     by_id.append(self)
 
-    self._patch = None
+    self._patch_info = None
 
   @property
-  def patch(self):
-    return self._patch
+  def patch_info(self):
+    return self._patch_info
 
-  @patch.setter
-  def patch(self, patch):
-    if patch:
+  @patch_info.setter
+  def patch_info(self, info):
+    if info:
       assert self.style == "diff"
-    self._patch = patch
+    self._patch_info = info
 
   def NetStashChange(self):
     return (sum(sr.size() for (_, sr) in self.stash_before) -
@@ -224,7 +229,7 @@
     self.use_stash = []
     self.style = "new"
     self.src_ranges = RangeSet()
-    self.patch = None
+    self.patch_info = None
 
   def __str__(self):
     return (str(self.id) + ": <" + str(self.src_ranges) + " " + self.style +
@@ -462,16 +467,7 @@
     self.AbbreviateSourceNames()
     self.FindTransfers()
 
-    # Find the ordering dependencies among transfers (this is O(n^2)
-    # in the number of transfers).
-    self.GenerateDigraph()
-    # Find a sequence of transfers that satisfies as many ordering
-    # dependencies as possible (heuristically).
-    self.FindVertexSequence()
-    # Fix up the ordering dependencies that the sequence didn't
-    # satisfy.
-    self.ReverseBackwardEdges()
-    self.ImproveVertexSequence()
+    self.FindSequenceForTransfers()
 
     # Ensure the runtime stash size is under the limit.
     if common.OPTIONS.cache_size is not None:
@@ -829,7 +825,7 @@
             # These are identical; we don't need to generate a patch,
             # just issue copy commands on the device.
             xf.style = "move"
-            xf.patch = None
+            xf.patch_info = None
             tgt_size = xf.tgt_ranges.size() * self.tgt.blocksize
             if xf.src_ranges != xf.tgt_ranges:
               logger.info(
@@ -839,11 +835,10 @@
                       xf.tgt_name + " (from " + xf.src_name + ")"),
                   str(xf.tgt_ranges), str(xf.src_ranges))
           else:
-            if xf.patch:
-              # We have already generated the patch with imgdiff, while
-              # splitting large APKs (i.e. in FindTransfers()).
-              assert not self.disable_imgdiff
-              imgdiff = True
+            if xf.patch_info:
+              # We have already generated the patch (e.g. when splitting large
+              # APKs, or when reducing the stash size).
+              imgdiff = xf.patch_info.imgdiff
             else:
               imgdiff = self.CanUseImgdiff(
                   xf.tgt_name, xf.tgt_ranges, xf.src_ranges)
@@ -854,85 +849,16 @@
         else:
           assert False, "unknown style " + xf.style
 
-    if diff_queue:
-      if self.threads > 1:
-        logger.info("Computing patches (using %d threads)...", self.threads)
-      else:
-        logger.info("Computing patches...")
-
-      diff_total = len(diff_queue)
-      patches = [None] * diff_total
-      error_messages = []
-
-      # Using multiprocessing doesn't give additional benefits, due to the
-      # pattern of the code. The diffing work is done by subprocess.call, which
-      # already runs in a separate process (not affected much by the GIL -
-      # Global Interpreter Lock). Using multiprocess also requires either a)
-      # writing the diff input files in the main process before forking, or b)
-      # reopening the image file (SparseImage) in the worker processes. Doing
-      # neither of them further improves the performance.
-      lock = threading.Lock()
-      def diff_worker():
-        while True:
-          with lock:
-            if not diff_queue:
-              return
-            xf_index, imgdiff, patch_index = diff_queue.pop()
-            xf = self.transfers[xf_index]
-
-          patch = xf.patch
-          if not patch:
-            src_ranges = xf.src_ranges
-            tgt_ranges = xf.tgt_ranges
-
-            src_file = common.MakeTempFile(prefix="src-")
-            with open(src_file, "wb") as fd:
-              self.src.WriteRangeDataToFd(src_ranges, fd)
-
-            tgt_file = common.MakeTempFile(prefix="tgt-")
-            with open(tgt_file, "wb") as fd:
-              self.tgt.WriteRangeDataToFd(tgt_ranges, fd)
-
-            message = []
-            try:
-              patch = compute_patch(src_file, tgt_file, imgdiff)
-            except ValueError as e:
-              message.append(
-                  "Failed to generate %s for %s: tgt=%s, src=%s:\n%s" % (
-                      "imgdiff" if imgdiff else "bsdiff",
-                      xf.tgt_name if xf.tgt_name == xf.src_name else
-                      xf.tgt_name + " (from " + xf.src_name + ")",
-                      xf.tgt_ranges, xf.src_ranges, e.message))
-            if message:
-              with lock:
-                error_messages.extend(message)
-
-          with lock:
-            patches[patch_index] = (xf_index, patch)
-
-      threads = [threading.Thread(target=diff_worker)
-                 for _ in range(self.threads)]
-      for th in threads:
-        th.start()
-      while threads:
-        threads.pop().join()
-
-      if error_messages:
-        logger.error('ERROR:')
-        logger.error('\n'.join(error_messages))
-        logger.error('\n\n\n')
-        sys.exit(1)
-    else:
-      patches = []
+    patches = self.ComputePatchesForInputList(diff_queue, False)
 
     offset = 0
     with open(prefix + ".patch.dat", "wb") as patch_fd:
-      for index, patch in patches:
+      for index, patch_info, _ in patches:
         xf = self.transfers[index]
-        xf.patch_len = len(patch)
+        xf.patch_len = len(patch_info.content)
         xf.patch_start = offset
         offset += xf.patch_len
-        patch_fd.write(patch)
+        patch_fd.write(patch_info.content)
 
         tgt_size = xf.tgt_ranges.size() * self.tgt.blocksize
         logger.info(
@@ -999,6 +925,32 @@
       for i in range(s, e):
         assert touched[i] == 1
 
+  def FindSequenceForTransfers(self):
+    """Finds a sequence for the given transfers.
+
+    The goal is to minimize the violation of order dependencies between these
+    transfers, so that fewer blocks are stashed when applying the update.
+    """
+
+    # Clear the existing dependencies between transfers.
+    for xf in self.transfers:
+      xf.goes_before = OrderedDict()
+      xf.goes_after = OrderedDict()
+
+      xf.stash_before = []
+      xf.use_stash = []
+
+    # Find the ordering dependencies among transfers (this is O(n^2)
+    # in the number of transfers).
+    self.GenerateDigraph()
+    # Find a sequence of transfers that satisfies as many ordering
+    # dependencies as possible (heuristically).
+    self.FindVertexSequence()
+    # Fix up the ordering dependencies that the sequence didn't
+    # satisfy.
+    self.ReverseBackwardEdges()
+    self.ImproveVertexSequence()
+
   def ImproveVertexSequence(self):
     logger.info("Improving vertex order...")
 
@@ -1248,6 +1200,105 @@
           b.goes_before[a] = size
           a.goes_after[b] = size
 
+  def ComputePatchesForInputList(self, diff_queue, compress_target):
+    """Returns a list of patch information for the input list of transfers.
+
+    Args:
+      diff_queue: A list of transfers with style 'diff'.
+      compress_target: If True, compresses the target ranges of each transfer
+          and saves the compressed size.
+
+    Returns:
+      A list of (transfer order, patch_info, compressed_size) tuples.
+    """
+
+    if not diff_queue:
+      return []
+
+    if self.threads > 1:
+      logger.info("Computing patches (using %d threads)...", self.threads)
+    else:
+      logger.info("Computing patches...")
+
+    diff_total = len(diff_queue)
+    patches = [None] * diff_total
+    error_messages = []
+
+    # Using multiprocessing doesn't give additional benefits, due to the
+    # pattern of the code. The diffing work is done by subprocess.call, which
+    # already runs in a separate process (not affected much by the GIL -
+    # Global Interpreter Lock). Using multiprocessing also requires either a)
+    # writing the diff input files in the main process before forking, or b)
+    # reopening the image file (SparseImage) in the worker processes. Neither
+    # of these options further improves the performance.
+    lock = threading.Lock()
+
+    def diff_worker():
+      while True:
+        with lock:
+          if not diff_queue:
+            return
+          xf_index, imgdiff, patch_index = diff_queue.pop()
+          xf = self.transfers[xf_index]
+
+        message = []
+        compressed_size = None
+
+        patch_info = xf.patch_info
+        if not patch_info:
+          src_file = common.MakeTempFile(prefix="src-")
+          with open(src_file, "wb") as fd:
+            self.src.WriteRangeDataToFd(xf.src_ranges, fd)
+
+          tgt_file = common.MakeTempFile(prefix="tgt-")
+          with open(tgt_file, "wb") as fd:
+            self.tgt.WriteRangeDataToFd(xf.tgt_ranges, fd)
+
+          try:
+            patch_info = compute_patch(src_file, tgt_file, imgdiff)
+          except ValueError as e:
+            message.append(
+                "Failed to generate %s for %s: tgt=%s, src=%s:\n%s" % (
+                    "imgdiff" if imgdiff else "bsdiff",
+                    xf.tgt_name if xf.tgt_name == xf.src_name else
+                    xf.tgt_name + " (from " + xf.src_name + ")",
+                    xf.tgt_ranges, xf.src_ranges, e.message))
+
+        if compress_target:
+          tgt_data = self.tgt.ReadRangeSet(xf.tgt_ranges)
+          try:
+            # Compresses with the default level
+            compress_obj = zlib.compressobj(6, zlib.DEFLATED, -zlib.MAX_WBITS)
+            compressed_data = (compress_obj.compress("".join(tgt_data))
+                               + compress_obj.flush())
+            compressed_size = len(compressed_data)
+          except zlib.error as e:
+            message.append(
+                "Failed to compress the data in target range {} for {}:\n"
+                "{}".format(xf.tgt_ranges, xf.tgt_name, e.message))
+
+        if message:
+          with lock:
+            error_messages.extend(message)
+
+        with lock:
+          patches[patch_index] = (xf_index, patch_info, compressed_size)
+
+    threads = [threading.Thread(target=diff_worker)
+               for _ in range(self.threads)]
+    for th in threads:
+      th.start()
+    while threads:
+      threads.pop().join()
+
+    if error_messages:
+      logger.error('ERROR:')
+      logger.error('\n'.join(error_messages))
+      logger.error('\n\n\n')
+      sys.exit(1)
+
+    return patches
+
   def FindTransfers(self):
     """Parse the file_map to generate all the transfers."""
 
@@ -1585,7 +1636,7 @@
                                 self.tgt.RangeSha1(tgt_ranges),
                                 self.src.RangeSha1(src_ranges),
                                 "diff", self.transfers)
-      transfer_split.patch = patch
+      transfer_split.patch_info = PatchInfo(True, patch)
 
   def AbbreviateSourceNames(self):
     for k in self.src.file_map.keys():