blockimgdiff: Factor out the diff_worker
We will call it earlier to compute the patch sizes, and to choose which
transfers to convert to 'new'.
Bug: 120561199
Test: Generate an incremental update on shiner
Change-Id: I29a0c8e75c9e5b66a266c1387186692a86fcbe43
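The factored-out helper is groundwork for the follow-up described above:
computing patches early makes their sizes available when deciding which
transfers to ship as 'new' instead. A hedged sketch of that intended use;
the selection policy and the ConvertToNew name are assumptions, not part
of this change:

    patches = self.ComputePatchesForInputList(diff_queue, True)
    for xf_index, patch_info, compressed_size in patches:
      # Hypothetical policy: prefer 'new' when the deflated target data is
      # no larger than the patch that would otherwise be applied.
      if patch_info and compressed_size <= len(patch_info.content):
        self.transfers[xf_index].ConvertToNew()  # hypothetical method name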
diff --git a/tools/releasetools/blockimgdiff.py b/tools/releasetools/blockimgdiff.py
index 2d20e23..b5e01d3 100644
--- a/tools/releasetools/blockimgdiff.py
+++ b/tools/releasetools/blockimgdiff.py
@@ -26,7 +26,8 @@
import re
import sys
import threading
-from collections import deque, OrderedDict
+import zlib
+from collections import deque, namedtuple, OrderedDict
from hashlib import sha1
import common
@@ -36,8 +37,12 @@
logger = logging.getLogger(__name__)
+# The tuple contains the style and bytes of a bsdiff|imgdiff patch.
+PatchInfo = namedtuple("PatchInfo", ["imgdiff", "content"])
+
def compute_patch(srcfile, tgtfile, imgdiff=False):
+ """Calls bsdiff|imgdiff to compute the patch data, returns a PatchInfo."""
patchfile = common.MakeTempFile(prefix='patch-')
cmd = ['imgdiff', '-z'] if imgdiff else ['bsdiff']
@@ -52,7 +57,7 @@
raise ValueError(output)
with open(patchfile, 'rb') as f:
- return f.read()
+ return PatchInfo(imgdiff, f.read())
class Image(object):
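compute_patch() now returns a PatchInfo instead of raw bytes, so callers keep
both the payload and the tool that produced it. A minimal sketch of the new
contract, assuming src_file, tgt_file and patch_fd are already set up:

    patch_info = compute_patch(src_file, tgt_file, imgdiff=True)
    assert patch_info.imgdiff            # records which tool produced the patch
    patch_fd.write(patch_info.content)   # the raw patch bytes, as before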
@@ -203,17 +208,17 @@
self.id = len(by_id)
by_id.append(self)
- self._patch = None
+ self._patch_info = None
@property
- def patch(self):
- return self._patch
+ def patch_info(self):
+ return self._patch_info
- @patch.setter
- def patch(self, patch):
- if patch:
+ @patch_info.setter
+ def patch_info(self, info):
+ if info:
assert self.style == "diff"
- self._patch = patch
+ self._patch_info = info
def NetStashChange(self):
return (sum(sr.size() for (_, sr) in self.stash_before) -
@@ -224,7 +229,7 @@
self.use_stash = []
self.style = "new"
self.src_ranges = RangeSet()
- self.patch = None
+ self.patch_info = None
def __str__(self):
return (str(self.id) + ": <" + str(self.src_ranges) + " " + self.style +
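The renamed setter keeps the old invariant: only a transfer with style 'diff'
may carry a patch, while clearing to None is always allowed. A small sketch of
what it enforces, assuming a Transfer instance xf:

    xf.style = "new"
    xf.patch_info = None                      # fine: clearing is allowed
    xf.patch_info = PatchInfo(False, b"..")   # AssertionError: style != "diff"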
@@ -462,16 +467,7 @@
self.AbbreviateSourceNames()
self.FindTransfers()
- # Find the ordering dependencies among transfers (this is O(n^2)
- # in the number of transfers).
- self.GenerateDigraph()
- # Find a sequence of transfers that satisfies as many ordering
- # dependencies as possible (heuristically).
- self.FindVertexSequence()
- # Fix up the ordering dependencies that the sequence didn't
- # satisfy.
- self.ReverseBackwardEdges()
- self.ImproveVertexSequence()
+ self.FindSequenceForTransfers()
# Ensure the runtime stash size is under the limit.
if common.OPTIONS.cache_size is not None:
@@ -829,7 +825,7 @@
# These are identical; we don't need to generate a patch,
# just issue copy commands on the device.
xf.style = "move"
- xf.patch = None
+ xf.patch_info = None
tgt_size = xf.tgt_ranges.size() * self.tgt.blocksize
if xf.src_ranges != xf.tgt_ranges:
logger.info(
@@ -839,11 +835,10 @@
xf.tgt_name + " (from " + xf.src_name + ")"),
str(xf.tgt_ranges), str(xf.src_ranges))
else:
- if xf.patch:
- # We have already generated the patch with imgdiff, while
- # splitting large APKs (i.e. in FindTransfers()).
- assert not self.disable_imgdiff
- imgdiff = True
+ if xf.patch_info:
+ # We have already generated the patch (e.g. when splitting large
+ # APKs or reducing the stash size).
+ imgdiff = xf.patch_info.imgdiff
else:
imgdiff = self.CanUseImgdiff(
xf.tgt_name, xf.tgt_ranges, xf.src_ranges)
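Carrying the style inside PatchInfo replaces the old assumption that every
precomputed patch came from imgdiff; once patches can also be precomputed
while reducing the stash size, a bsdiff patch may arrive here too. The
decision reduces to a sketch like:

    imgdiff = (xf.patch_info.imgdiff if xf.patch_info else
               self.CanUseImgdiff(xf.tgt_name, xf.tgt_ranges, xf.src_ranges))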
@@ -854,85 +849,16 @@
else:
assert False, "unknown style " + xf.style
- if diff_queue:
- if self.threads > 1:
- logger.info("Computing patches (using %d threads)...", self.threads)
- else:
- logger.info("Computing patches...")
-
- diff_total = len(diff_queue)
- patches = [None] * diff_total
- error_messages = []
-
- # Using multiprocessing doesn't give additional benefits, due to the
- # pattern of the code. The diffing work is done by subprocess.call, which
- # already runs in a separate process (not affected much by the GIL -
- # Global Interpreter Lock). Using multiprocess also requires either a)
- # writing the diff input files in the main process before forking, or b)
- # reopening the image file (SparseImage) in the worker processes. Doing
- # neither of them further improves the performance.
- lock = threading.Lock()
- def diff_worker():
- while True:
- with lock:
- if not diff_queue:
- return
- xf_index, imgdiff, patch_index = diff_queue.pop()
- xf = self.transfers[xf_index]
-
- patch = xf.patch
- if not patch:
- src_ranges = xf.src_ranges
- tgt_ranges = xf.tgt_ranges
-
- src_file = common.MakeTempFile(prefix="src-")
- with open(src_file, "wb") as fd:
- self.src.WriteRangeDataToFd(src_ranges, fd)
-
- tgt_file = common.MakeTempFile(prefix="tgt-")
- with open(tgt_file, "wb") as fd:
- self.tgt.WriteRangeDataToFd(tgt_ranges, fd)
-
- message = []
- try:
- patch = compute_patch(src_file, tgt_file, imgdiff)
- except ValueError as e:
- message.append(
- "Failed to generate %s for %s: tgt=%s, src=%s:\n%s" % (
- "imgdiff" if imgdiff else "bsdiff",
- xf.tgt_name if xf.tgt_name == xf.src_name else
- xf.tgt_name + " (from " + xf.src_name + ")",
- xf.tgt_ranges, xf.src_ranges, e.message))
- if message:
- with lock:
- error_messages.extend(message)
-
- with lock:
- patches[patch_index] = (xf_index, patch)
-
- threads = [threading.Thread(target=diff_worker)
- for _ in range(self.threads)]
- for th in threads:
- th.start()
- while threads:
- threads.pop().join()
-
- if error_messages:
- logger.error('ERROR:')
- logger.error('\n'.join(error_messages))
- logger.error('\n\n\n')
- sys.exit(1)
- else:
- patches = []
+ patches = self.ComputePatchesForInputList(diff_queue, False)
offset = 0
with open(prefix + ".patch.dat", "wb") as patch_fd:
- for index, patch in patches:
+ for index, patch_info, _ in patches:
xf = self.transfers[index]
- xf.patch_len = len(patch)
+ xf.patch_len = len(patch_info.content)
xf.patch_start = offset
offset += xf.patch_len
- patch_fd.write(patch)
+ patch_fd.write(patch_info.content)
tgt_size = xf.tgt_ranges.size() * self.tgt.blocksize
logger.info(
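Every patch is appended to one monolithic patch.dat stream, with
xf.patch_start and xf.patch_len recording each transfer's slice. A minimal
reader-side sketch of that layout, assuming the same prefix:

    with open(prefix + ".patch.dat", "rb") as f:
      f.seek(xf.patch_start)
      patch = f.read(xf.patch_len)  # exactly this transfer's patch bytes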
@@ -999,6 +925,32 @@
for i in range(s, e):
assert touched[i] == 1
+ def FindSequenceForTransfers(self):
+ """Finds a sequence for the given transfers.
+
+ The goal is to minimize the violation of order dependencies between these
+ transfers, so that fewer blocks are stashed when applying the update.
+ """
+
+ # Clear the existing dependencies among transfers.
+ for xf in self.transfers:
+ xf.goes_before = OrderedDict()
+ xf.goes_after = OrderedDict()
+
+ xf.stash_before = []
+ xf.use_stash = []
+
+ # Find the ordering dependencies among transfers (this is O(n^2)
+ # in the number of transfers).
+ self.GenerateDigraph()
+ # Find a sequence of transfers that satisfies as many ordering
+ # dependencies as possible (heuristically).
+ self.FindVertexSequence()
+ # Fix up the ordering dependencies that the sequence didn't
+ # satisfy.
+ self.ReverseBackwardEdges()
+ self.ImproveVertexSequence()
+
def ImproveVertexSequence(self):
logger.info("Improving vertex order...")
@@ -1248,6 +1200,105 @@
b.goes_before[a] = size
a.goes_after[b] = size
+ def ComputePatchesForInputList(self, diff_queue, compress_target):
+ """Returns a list of patch information for the input list of transfers.
+
+ Args:
+ diff_queue: A list of transfers with style 'diff'.
+ compress_target: If True, compresses the target ranges of each
+ transfer and saves the compressed size.
+
+ Returns:
+ A list of (transfer order, patch_info, compressed_size) tuples.
+ """
+
+ if not diff_queue:
+ return []
+
+ if self.threads > 1:
+ logger.info("Computing patches (using %d threads)...", self.threads)
+ else:
+ logger.info("Computing patches...")
+
+ diff_total = len(diff_queue)
+ patches = [None] * diff_total
+ error_messages = []
+
+ # Using multiprocessing doesn't give additional benefits, due to the
+ # pattern of the code. The diffing work is done by subprocess.call, which
+ # already runs in a separate process (not affected much by the GIL -
+ # Global Interpreter Lock). Using multiprocess also requires either a)
+ # writing the diff input files in the main process before forking, or b)
+ # reopening the image file (SparseImage) in the worker processes. Doing
+ # neither of them further improves the performance.
+ lock = threading.Lock()
+
+ def diff_worker():
+ while True:
+ with lock:
+ if not diff_queue:
+ return
+ xf_index, imgdiff, patch_index = diff_queue.pop()
+ xf = self.transfers[xf_index]
+
+ message = []
+ compressed_size = None
+
+ patch_info = xf.patch_info
+ if not patch_info:
+ src_file = common.MakeTempFile(prefix="src-")
+ with open(src_file, "wb") as fd:
+ self.src.WriteRangeDataToFd(xf.src_ranges, fd)
+
+ tgt_file = common.MakeTempFile(prefix="tgt-")
+ with open(tgt_file, "wb") as fd:
+ self.tgt.WriteRangeDataToFd(xf.tgt_ranges, fd)
+
+ try:
+ patch_info = compute_patch(src_file, tgt_file, imgdiff)
+ except ValueError as e:
+ message.append(
+ "Failed to generate %s for %s: tgt=%s, src=%s:\n%s" % (
+ "imgdiff" if imgdiff else "bsdiff",
+ xf.tgt_name if xf.tgt_name == xf.src_name else
+ xf.tgt_name + " (from " + xf.src_name + ")",
+ xf.tgt_ranges, xf.src_ranges, e.message))
+
+ if compress_target:
+ tgt_data = self.tgt.ReadRangeSet(xf.tgt_ranges)
+ try:
+ # Compresses with the default level
+ compress_obj = zlib.compressobj(6, zlib.DEFLATED, -zlib.MAX_WBITS)
+ compressed_data = (compress_obj.compress("".join(tgt_data))
+ + compress_obj.flush())
+ compressed_size = len(compressed_data)
+ except zlib.error as e:
+ message.append(
+ "Failed to compress the data in target range {} for {}:\n"
+ "{}".format(xf.tgt_ranges, xf.tgt_name, e.message))
+
+ if message:
+ with lock:
+ error_messages.extend(message)
+
+ with lock:
+ patches[patch_index] = (xf_index, patch_info, compressed_size)
+
+ threads = [threading.Thread(target=diff_worker)
+ for _ in range(self.threads)]
+ for th in threads:
+ th.start()
+ while threads:
+ threads.pop().join()
+
+ if error_messages:
+ logger.error('ERROR:')
+ logger.error('\n'.join(error_messages))
+ logger.error('\n\n\n')
+ sys.exit(1)
+
+ return patches
+
def FindTransfers(self):
"""Parse the file_map to generate all the transfers."""
@@ -1585,7 +1636,7 @@
self.tgt.RangeSha1(tgt_ranges),
self.src.RangeSha1(src_ranges),
"diff", self.transfers)
- transfer_split.patch = patch
+ transfer_split.patch_info = PatchInfo(True, patch)
def AbbreviateSourceNames(self):
for k in self.src.file_map.keys():
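The hard-coded True marks the precomputed split-APK patch as an imgdiff
patch, which the diff worker will pick up instead of recomputing. Since
PatchInfo is a namedtuple, the equivalent keyword form makes the intent
explicit:

    transfer_split.patch_info = PatchInfo(imgdiff=True, content=patch)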