Tianjie Xu | 41976c7 | 2019-07-03 13:57:01 -0700 | [diff] [blame] | 1 | # Copyright (C) 2019 The Android Open Source Project |
| 2 | # |
| 3 | # Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | # you may not use this file except in compliance with the License. |
| 5 | # You may obtain a copy of the License at |
| 6 | # |
| 7 | # http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | # |
| 9 | # Unless required by applicable law or agreed to in writing, software |
| 10 | # distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | # See the License for the specific language governing permissions and limitations under the License. |
| 13 | |
| 14 | import os |
| 15 | import threading |
| 16 | from hashlib import sha1 |
| 17 | |
| 18 | from rangelib import RangeSet |
| 19 | |
| 20 | __all__ = ["EmptyImage", "DataImage", "FileImage"] |
| 21 | |
| 22 | |
class Image(object):
    """Abstract base for the image classes defined in this module.

    Every method here raises NotImplementedError; EmptyImage, DataImage and
    FileImage provide the concrete behavior.
    """

    def RangeSha1(self, ranges):
        """Subclasses return the SHA-1 hex digest of the data in `ranges`."""
        raise NotImplementedError()

    def ReadRangeSet(self, ranges):
        """Subclasses return the data covered by `ranges`."""
        raise NotImplementedError()

    def TotalSha1(self, include_clobbered_blocks=False):
        """Subclasses return the SHA-1 hex digest of the image contents."""
        raise NotImplementedError()

    def WriteRangeDataToFd(self, ranges, fd):
        """Subclasses write the data covered by `ranges` to `fd`."""
        raise NotImplementedError()
| 35 | |
| 36 | |
class EmptyImage(Image):
    """An image that holds no blocks at all."""

    def __init__(self):
        # Scalar bookkeeping first, then the (all empty) range sets.
        self.blocksize = 4096
        self.total_blocks = 0
        self.care_map = RangeSet()
        self.clobbered_blocks = RangeSet()
        self.extended = RangeSet()
        self.file_map = {}
        self.hashtree_info = None

    def RangeSha1(self, ranges):
        """Return the SHA-1 hex digest of no data; `ranges` is not consulted."""
        hasher = sha1()
        return hasher.hexdigest()

    def ReadRangeSet(self, ranges):
        """Return an empty tuple, since there is nothing to read."""
        return ()

    def TotalSha1(self, include_clobbered_blocks=False):
        """Return the SHA-1 hex digest of no data."""
        # The clobbered set of an EmptyImage is expected to be empty, which
        # makes include_clobbered_blocks irrelevant here.
        clobbered_count = self.clobbered_blocks.size()
        assert clobbered_count == 0
        hasher = sha1()
        return hasher.hexdigest()

    def WriteRangeDataToFd(self, ranges, fd):
        """Always raise ValueError: an EmptyImage has no data to write."""
        raise ValueError("Can't write data from EmptyImage to file")
| 63 | |
| 64 | |
class DataImage(Image):
    """An image wrapped around a single string of data.

    The data is kept in memory and addressed in blocks of `blocksize` bytes.
    """

    def __init__(self, data, trim=False, pad=False):
        """Wrap `data`, optionally trimming or padding it to whole blocks.

        Args:
          data: The image contents. Its length must be a multiple of the
              block size unless `trim` or `pad` is given.
          trim: Drop the trailing partial block, if there is one.
          pad: Fill up the trailing partial block with NULs, if there is one.

        Raises:
          ValueError: The data is not block-aligned and neither `trim` nor
              `pad` was requested.
        """
        self.data = data
        self.blocksize = 4096

        assert not (trim and pad)

        # Use a NUL of the same kind as the data. A text '\0' can neither be
        # appended to nor compare equal to binary data on Python 3, which would
        # break padding and hide every all-zero block.
        if isinstance(self.data, (bytes, bytearray)):
            nul = b'\0'
        else:
            nul = '\0'

        partial = len(self.data) % self.blocksize
        padded = False
        if partial > 0:
            if trim:
                self.data = self.data[:-partial]
            elif pad:
                # Build a new object rather than using +=, so that a mutable
                # buffer passed in by the caller is not modified in place.
                self.data = self.data + nul * (self.blocksize - partial)
                padded = True
            else:
                raise ValueError(("data for DataImage must be multiple of %d bytes "
                                  "unless trim or pad is specified") %
                                 (self.blocksize,))

        assert len(self.data) % self.blocksize == 0

        self.total_blocks = len(self.data) // self.blocksize
        self.care_map = RangeSet(data=(0, self.total_blocks))
        # When the last block is padded, we always write the whole block even for
        # incremental OTAs. Because otherwise the last block may get skipped if
        # unchanged for an incremental, but would fail the post-install
        # verification if it has non-zero contents in the padding bytes.
        # Bug: 23828506
        if padded:
            clobbered_blocks = [self.total_blocks - 1, self.total_blocks]
        else:
            clobbered_blocks = []
        # Keep this as a RangeSet, as EmptyImage and FileImage do, because
        # TotalSha1() hands it to RangeSet.subtract().
        # NOTE(review): rangelib is not visible from this file; confirm that
        # subtract() needs a RangeSet operand rather than a plain list.
        if clobbered_blocks:
            self.clobbered_blocks = RangeSet(data=clobbered_blocks)
        else:
            self.clobbered_blocks = RangeSet()
        self.extended = RangeSet()
        # Present on the other image classes; a DataImage never has one.
        self.hashtree_info = None

        # Classify every block that is not clobbered as all-zero or not. Each
        # block i is recorded as the pair (i, i+1).
        zero_blocks = []
        nonzero_blocks = []
        reference = nul * self.blocksize

        scan_limit = self.total_blocks - 1 if padded else self.total_blocks
        for i in range(scan_limit):
            d = self.data[i * self.blocksize:(i + 1) * self.blocksize]
            target = zero_blocks if d == reference else nonzero_blocks
            target.append(i)
            target.append(i + 1)

        assert zero_blocks or nonzero_blocks or clobbered_blocks

        self.file_map = dict()
        if zero_blocks:
            self.file_map["__ZERO"] = RangeSet(data=zero_blocks)
        if nonzero_blocks:
            self.file_map["__NONZERO"] = RangeSet(data=nonzero_blocks)
        if clobbered_blocks:
            self.file_map["__COPY"] = RangeSet(data=clobbered_blocks)

    def _GetRangeData(self, ranges):
        """Yield the data of each (start, end) block pair in `ranges`.

        `end` is exclusive; each pair produces one contiguous slice.
        """
        for s, e in ranges:
            yield self.data[s * self.blocksize:e * self.blocksize]

    def RangeSha1(self, ranges):
        """Return the SHA-1 hex digest of the data covered by `ranges`."""
        h = sha1()
        for data in self._GetRangeData(ranges):  # pylint: disable=not-an-iterable
            h.update(data)
        return h.hexdigest()

    def ReadRangeSet(self, ranges):
        """Return a list with one data slice per range in `ranges`."""
        return list(self._GetRangeData(ranges))

    def TotalSha1(self, include_clobbered_blocks=False):
        """Return the SHA-1 hex digest of the image.

        Args:
          include_clobbered_blocks: If true, hash all of the data; otherwise
              hash the care map with the clobbered blocks removed.
        """
        if not include_clobbered_blocks:
            return self.RangeSha1(self.care_map.subtract(self.clobbered_blocks))
        return sha1(self.data).hexdigest()

    def WriteRangeDataToFd(self, ranges, fd):
        """Write the data covered by `ranges` to `fd`, in range order."""
        for data in self._GetRangeData(ranges):  # pylint: disable=not-an-iterable
            fd.write(data)
| 147 | |
| 148 | |
class FileImage(Image):
    """An image wrapped around a raw image file."""

    def __init__(self, path, hashtree_info_generator=None):
        """Open the image file at `path` and classify its blocks.

        Args:
          path: Path of the raw image file. Its size must be a multiple of
              the block size.
          hashtree_info_generator: Optional object whose Generate(image)
              result is stored as `hashtree_info`; that result must provide
              a `hashtree_range` attribute.

        Raises:
          ValueError: The file size is not a multiple of the block size.
        """
        self.path = path
        self.blocksize = 4096
        self._file_size = os.path.getsize(self.path)

        # Validate before opening, so a rejected file leaves no handle open.
        if self._file_size % self.blocksize != 0:
            # The arguments must be a tuple: '%' binds tighter than ',', so
            # without the parentheses the formatting itself raises TypeError.
            raise ValueError(
                "Size of file %s must be multiple of %d bytes, but is %d"
                % (self.path, self.blocksize, self._file_size))

        self._file = open(self.path, 'rb')

        self.total_blocks = self._file_size // self.blocksize
        self.care_map = RangeSet(data=(0, self.total_blocks))
        self.clobbered_blocks = RangeSet()
        self.extended = RangeSet()

        self.generator_lock = threading.Lock()

        self.hashtree_info = None
        if hashtree_info_generator:
            self.hashtree_info = hashtree_info_generator.Generate(self)

        # Classify every block as all-zero or not. Each block i is recorded as
        # the pair (i, i+1).
        zero_blocks = []
        nonzero_blocks = []
        # The file is opened in binary mode, so the reference must be bytes;
        # a text string never compares equal to bytes on Python 3.
        reference = b'\0' * self.blocksize

        # NOTE(review): Generate() above receives this image and may have read
        # through self._file; rewind so the scan really starts at block 0.
        self._file.seek(0)
        for i in range(self.total_blocks):
            d = self._file.read(self.blocksize)
            target = zero_blocks if d == reference else nonzero_blocks
            target.append(i)
            target.append(i + 1)

        assert zero_blocks or nonzero_blocks

        self.file_map = {}
        if zero_blocks:
            self.file_map["__ZERO"] = RangeSet(data=zero_blocks)
        if nonzero_blocks:
            self.file_map["__NONZERO"] = RangeSet(data=nonzero_blocks)
        if self.hashtree_info:
            self.file_map["__HASHTREE"] = self.hashtree_info.hashtree_range

    def __del__(self):
        # __init__ may have raised before the file was opened, in which case
        # there is no _file attribute to close.
        opened = getattr(self, "_file", None)
        if opened is not None:
            opened.close()

    def _GetRangeData(self, ranges):
        """Yield the file data for `ranges`, one block at a time.

        Each item of `ranges` is a (start, end) block pair with `end`
        exclusive.
        """
        # Use a lock to protect the generator so that we will not run two
        # instances of this generator on the same object simultaneously.
        with self.generator_lock:
            for s, e in ranges:
                self._file.seek(s * self.blocksize)
                for _ in range(s, e):
                    yield self._file.read(self.blocksize)

    def RangeSha1(self, ranges):
        """Return the SHA-1 hex digest of the data covered by `ranges`."""
        h = sha1()
        for data in self._GetRangeData(ranges):  # pylint: disable=not-an-iterable
            h.update(data)
        return h.hexdigest()

    def ReadRangeSet(self, ranges):
        """Return a list with one entry per block covered by `ranges`."""
        return list(self._GetRangeData(ranges))

    def TotalSha1(self, include_clobbered_blocks=False):
        """Return the SHA-1 hex digest of every block in the care map.

        `include_clobbered_blocks` has no effect: the clobbered set is
        asserted to be empty.
        """
        assert not self.clobbered_blocks
        return self.RangeSha1(self.care_map)

    def WriteRangeDataToFd(self, ranges, fd):
        """Write the data covered by `ranges` to `fd`, block by block."""
        for data in self._GetRangeData(ranges):  # pylint: disable=not-an-iterable
            fd.write(data)