update_device.py: support OTA zip when using Omaha too.

Pass the offset and size of the payload within the served file to
UpdateHandler. Also add support for a payload offset in
update_payload.Payload.

Test: applied a local OTA
Change-Id: Ib116ef2c23a11e298118f203814c4ea8dd1629af
diff --git a/scripts/update_device.py b/scripts/update_device.py
index ca7518d..1a0daf8 100755
--- a/scripts/update_device.py
+++ b/scripts/update_device.py
@@ -19,6 +19,7 @@
 
 import argparse
 import BaseHTTPServer
+import hashlib
 import logging
 import os
 import socket
@@ -98,6 +99,7 @@
 
   Attributes:
     serving_payload: path to the only payload file we are serving.
+    serving_range: the start offset and size tuple of the payload.
   """
 
   @staticmethod
@@ -148,12 +150,12 @@
     else:
       self.send_response(200)
 
-    stat = os.fstat(f.fileno())
+    serving_start, serving_size = self.serving_range
     start_range, end_range = self._parse_range(self.headers.get('range'),
-                                               stat.st_size)
+                                               serving_size)
     logging.info('Serving request for %s from %s [%d, %d) length: %d',
-                 self.path, self.serving_payload, start_range, end_range,
-                 end_range - start_range)
+                 self.path, self.serving_payload, serving_start + start_range,
+                 serving_start + end_range, end_range - start_range)
 
     self.send_header('Accept-Ranges', 'bytes')
     self.send_header('Content-Range',
@@ -161,11 +163,12 @@
                      '/' + str(end_range - start_range))
     self.send_header('Content-Length', end_range - start_range)
 
+    stat = os.fstat(f.fileno())
     self.send_header('Last-Modified', self.date_time_string(stat.st_mtime))
     self.send_header('Content-type', 'application/octet-stream')
     self.end_headers()
 
-    f.seek(start_range)
+    f.seek(serving_start + start_range)
     CopyFileObjLength(f, self.wfile, copy_length=end_range - start_range)
 
 
@@ -201,10 +204,19 @@
     self.send_header("Content-type", "text/xml")
     self.end_headers()
 
-    stat = os.fstat(f.fileno())
-    sha256sum = subprocess.check_output(['sha256sum', self.serving_payload])
-    payload_hash = sha256sum.split()[0]
-    payload = update_payload.Payload(f)
+    serving_start, serving_size = self.serving_range
+    sha256 = hashlib.sha256()
+    f.seek(serving_start)
+    bytes_to_hash = serving_size
+    while bytes_to_hash:
+      buf = f.read(min(bytes_to_hash, 1024 * 1024))
+      if not buf:
+        self.send_error(500, 'Payload too small')
+        return
+      sha256.update(buf)
+      bytes_to_hash -= len(buf)
+
+    payload = update_payload.Payload(f, payload_file_offset=serving_start)
     payload.Init()
 
     response_xml = '''
@@ -228,8 +240,9 @@
           </app>
         </response>
     '''.format(appid=appid, port=DEVICE_PORT,
-               metadata_size=payload.metadata_size, payload_hash=payload_hash,
-               payload_size=stat.st_size)
+               metadata_size=payload.metadata_size,
+               payload_hash=sha256.hexdigest(),
+               payload_size=serving_size)
     self.wfile.write(response_xml.strip())
     return
 
@@ -237,11 +250,12 @@
 class ServerThread(threading.Thread):
   """A thread for serving HTTP requests."""
 
-  def __init__(self, ota_filename):
+  def __init__(self, ota_filename, serving_range):
     threading.Thread.__init__(self)
-    # serving_payload is a class attribute and the UpdateHandler class is
-    # instantiated with every request.
+    # serving_payload and serving_range are class attributes and the
+    # UpdateHandler class is instantiated with every request.
     UpdateHandler.serving_payload = ota_filename
+    UpdateHandler.serving_range = serving_range
     self._httpd = BaseHTTPServer.HTTPServer(('127.0.0.1', 0), UpdateHandler)
     self.port = self._httpd.server_port
 
@@ -256,8 +270,8 @@
     self._httpd.socket.close()
 
 
-def StartServer(ota_filename):
-  t = ServerThread(ota_filename)
+def StartServer(ota_filename, serving_range):
+  t = ServerThread(ota_filename, serving_range)
   t.start()
   return t
 
@@ -333,8 +347,9 @@
 
 def main():
   parser = argparse.ArgumentParser(description='Android A/B OTA helper.')
-  parser.add_argument('otafile', metavar='ZIP', type=str,
-                      help='the OTA package file (a .zip file).')
+  parser.add_argument('otafile', metavar='PAYLOAD', type=str,
+                      help='the OTA package file (a .zip file) or raw payload \
+                      if device uses Omaha.')
   parser.add_argument('--file', action='store_true',
                       help='Push the file to the device before updating.')
   parser.add_argument('--no-push', action='store_true',
@@ -374,7 +389,12 @@
     # Update via sending the payload over the network with an "adb reverse"
     # command.
     payload_url = 'http://127.0.0.1:%d/payload' % DEVICE_PORT
-    server_thread = StartServer(args.otafile)
+    if use_omaha and zipfile.is_zipfile(args.otafile):
+      ota = AndroidOTAPackage(args.otafile)
+      serving_range = (ota.offset, ota.size)
+    else:
+      serving_range = (0, os.stat(args.otafile).st_size)
+    server_thread = StartServer(args.otafile, serving_range)
     cmds.append(
         ['reverse', 'tcp:%d' % DEVICE_PORT, 'tcp:%d' % server_thread.port])
     finalize_cmds.append(['reverse', '--remove', 'tcp:%d' % DEVICE_PORT])
diff --git a/scripts/update_payload/payload.py b/scripts/update_payload/payload.py
index d1a99ec..184f805 100644
--- a/scripts/update_payload/payload.py
+++ b/scripts/update_payload/payload.py
@@ -101,13 +101,15 @@
             hasher=hasher)
 
 
-  def __init__(self, payload_file):
+  def __init__(self, payload_file, payload_file_offset=0):
     """Initialize the payload object.
 
     Args:
       payload_file: update payload file object open for reading
+      payload_file_offset: the offset of the actual payload
     """
     self.payload_file = payload_file
+    self.payload_file_offset = payload_file_offset
     self.manifest_hasher = None
     self.is_init = False
     self.header = None
@@ -159,7 +161,8 @@
 
     return common.Read(
         self.payload_file, self.header.metadata_signature_len,
-        offset=self.header.size + self.header.manifest_len)
+        offset=self.payload_file_offset + self.header.size +
+        self.header.manifest_len)
 
   def ReadDataBlob(self, offset, length):
     """Reads and returns a single data blob from the update payload.
@@ -175,7 +178,8 @@
       PayloadError if a read error occurred.
     """
     return common.Read(self.payload_file, length,
-                       offset=self.data_offset + offset)
+                       offset=self.payload_file_offset + self.data_offset +
+                       offset)
 
   def Init(self):
     """Initializes the payload object.
@@ -194,6 +198,7 @@
     self.manifest_hasher = hashlib.sha256()
 
     # Read the file header.
+    self.payload_file.seek(self.payload_file_offset)
     self.header = self._ReadHeader()
 
     # Read the manifest.
@@ -246,7 +251,7 @@
 
   def ResetFile(self):
     """Resets the offset of the payload file to right past the manifest."""
-    self.payload_file.seek(self.data_offset)
+    self.payload_file.seek(self.payload_file_offset + self.data_offset)
 
   def IsDelta(self):
     """Returns True iff the payload appears to be a delta."""