update_device.py: support OTA zip packages when using Omaha too.

Pass the payload's offset and size within the served file to UpdateHandler.
Also add support for a payload file offset (payload_file_offset) in
update_payload.Payload.

Test: applied a local OTA
Change-Id: Ib116ef2c23a11e298118f203814c4ea8dd1629af
diff --git a/scripts/update_device.py b/scripts/update_device.py
index ca7518d..1a0daf8 100755
--- a/scripts/update_device.py
+++ b/scripts/update_device.py
@@ -19,6 +19,7 @@
 
 import argparse
 import BaseHTTPServer
+import hashlib
 import logging
 import os
 import socket
@@ -98,6 +99,7 @@
 
   Attributes:
     serving_payload: path to the only payload file we are serving.
+    serving_range: the start offset and size tuple of the payload.
   """
 
   @staticmethod
@@ -148,12 +150,12 @@
     else:
       self.send_response(200)
 
-    stat = os.fstat(f.fileno())
+    serving_start, serving_size = self.serving_range
     start_range, end_range = self._parse_range(self.headers.get('range'),
-                                               stat.st_size)
+                                               serving_size)
     logging.info('Serving request for %s from %s [%d, %d) length: %d',
-                 self.path, self.serving_payload, start_range, end_range,
-                 end_range - start_range)
+                 self.path, self.serving_payload, serving_start + start_range,
+                 serving_start + end_range, end_range - start_range)
 
     self.send_header('Accept-Ranges', 'bytes')
     self.send_header('Content-Range',
@@ -161,11 +163,12 @@
                      '/' + str(end_range - start_range))
     self.send_header('Content-Length', end_range - start_range)
 
+    stat = os.fstat(f.fileno())
     self.send_header('Last-Modified', self.date_time_string(stat.st_mtime))
     self.send_header('Content-type', 'application/octet-stream')
     self.end_headers()
 
-    f.seek(start_range)
+    f.seek(serving_start + start_range)
     CopyFileObjLength(f, self.wfile, copy_length=end_range - start_range)
 
 
@@ -201,10 +204,19 @@
     self.send_header("Content-type", "text/xml")
     self.end_headers()
 
-    stat = os.fstat(f.fileno())
-    sha256sum = subprocess.check_output(['sha256sum', self.serving_payload])
-    payload_hash = sha256sum.split()[0]
-    payload = update_payload.Payload(f)
+    serving_start, serving_size = self.serving_range
+    sha256 = hashlib.sha256()
+    f.seek(serving_start)
+    bytes_to_hash = serving_size
+    while bytes_to_hash:
+      buf = f.read(min(bytes_to_hash, 1024 * 1024))
+      if not buf:
+        self.send_error(500, 'Payload too small')
+        return
+      sha256.update(buf)
+      bytes_to_hash -= len(buf)
+
+    payload = update_payload.Payload(f, payload_file_offset=serving_start)
     payload.Init()
 
     response_xml = '''
@@ -228,8 +240,9 @@
           </app>
         </response>
     '''.format(appid=appid, port=DEVICE_PORT,
-               metadata_size=payload.metadata_size, payload_hash=payload_hash,
-               payload_size=stat.st_size)
+               metadata_size=payload.metadata_size,
+               payload_hash=sha256.hexdigest(),
+               payload_size=serving_size)
     self.wfile.write(response_xml.strip())
     return
 
@@ -237,11 +250,12 @@
 class ServerThread(threading.Thread):
   """A thread for serving HTTP requests."""
 
-  def __init__(self, ota_filename):
+  def __init__(self, ota_filename, serving_range):
     threading.Thread.__init__(self)
-    # serving_payload is a class attribute and the UpdateHandler class is
-    # instantiated with every request.
+    # serving_payload and serving_range are class attributes and the
+    # UpdateHandler class is instantiated with every request.
     UpdateHandler.serving_payload = ota_filename
+    UpdateHandler.serving_range = serving_range
     self._httpd = BaseHTTPServer.HTTPServer(('127.0.0.1', 0), UpdateHandler)
     self.port = self._httpd.server_port
 
@@ -256,8 +270,8 @@
     self._httpd.socket.close()
 
 
-def StartServer(ota_filename):
-  t = ServerThread(ota_filename)
+def StartServer(ota_filename, serving_range):
+  t = ServerThread(ota_filename, serving_range)
   t.start()
   return t
 
@@ -333,8 +347,9 @@
 
 def main():
   parser = argparse.ArgumentParser(description='Android A/B OTA helper.')
-  parser.add_argument('otafile', metavar='ZIP', type=str,
-                      help='the OTA package file (a .zip file).')
+  parser.add_argument('otafile', metavar='PAYLOAD', type=str,
+                      help='the OTA package file (a .zip file) or raw payload \
+                      if device uses Omaha.')
   parser.add_argument('--file', action='store_true',
                       help='Push the file to the device before updating.')
   parser.add_argument('--no-push', action='store_true',
@@ -374,7 +389,12 @@
     # Update via sending the payload over the network with an "adb reverse"
     # command.
     payload_url = 'http://127.0.0.1:%d/payload' % DEVICE_PORT
-    server_thread = StartServer(args.otafile)
+    if use_omaha and zipfile.is_zipfile(args.otafile):
+      ota = AndroidOTAPackage(args.otafile)
+      serving_range = (ota.offset, ota.size)
+    else:
+      serving_range = (0, os.stat(args.otafile).st_size)
+    server_thread = StartServer(args.otafile, serving_range)
     cmds.append(
         ['reverse', 'tcp:%d' % DEVICE_PORT, 'tcp:%d' % server_thread.port])
     finalize_cmds.append(['reverse', '--remove', 'tcp:%d' % DEVICE_PORT])