Update the OTA script in tm-mainline-prod to the latest version.

The last change in tm-mainline-prod was made last year (aosp/2040556); this change copies the current version of the script from udc-dev.

Test: local ota
Bug: 279622634
Change-Id: I6563122f21d7213bfa7200e28cdfa69bd95aa3e2
Merged-In: Iaa317a3a4b8addbca8ea987aee9953c78fa1a679
diff --git a/scripts/update_device.py b/scripts/update_device.py
index 72cee49..f94774b 100755
--- a/scripts/update_device.py
+++ b/scripts/update_device.py
@@ -25,6 +25,7 @@
 import hashlib
 import logging
 import os
+import re
 import socket
 import subprocess
 import sys
@@ -50,7 +51,7 @@
 DEVICE_PORT = 1234
 
 
-def CopyFileObjLength(fsrc, fdst, buffer_size=128 * 1024, copy_length=None):
+def CopyFileObjLength(fsrc, fdst, buffer_size=128 * 1024, copy_length=None, speed_limit=None):
   """Copy from a file object to another.

   This function is similar to shutil.copyfileobj except that it allows to copy
@@ -61,10 +62,18 @@
     fdst: destination file object where to write to.
     buffer_size: size of the copy buffer in memory.
     copy_length: maximum number of bytes to copy, or None to copy everything.
+    speed_limit: maximum copy rate in bytes per second; None or 0 means no limit.

   Returns:
     the number of bytes copied.
   """
+  # Cap the buffer at 1/32 of the per-second budget so throttled traffic is
+  # smooth, but keep it >= 1 byte since a 0-byte read would end the copy.
+  if speed_limit:
+    logging.info("Applying speed limit: %s", speed_limit)
+    buffer_size = max(1, min(speed_limit // 32, buffer_size))
+
+  start_time = time.time()
   copied = 0
   while True:
     chunk_size = buffer_size
@@ -75,6 +84,11 @@
     buf = fsrc.read(chunk_size)
     if not buf:
       break
+    if speed_limit:
+      expected_duration = copied / speed_limit
+      actual_duration = time.time() - start_time
+      if actual_duration < expected_duration:
+        time.sleep(expected_duration - actual_duration)
     fdst.write(buf)
     copied += len(buf)
   return copied
@@ -211,7 +225,8 @@
     self.end_headers()
 
     f.seek(serving_start + start_range)
-    CopyFileObjLength(f, self.wfile, copy_length=end_range - start_range)
+    CopyFileObjLength(f, self.wfile, copy_length=end_range -
+                      start_range, speed_limit=self.speed_limit)
 
   def do_POST(self):  # pylint: disable=invalid-name
     """Reply with the omaha response xml."""
@@ -291,12 +306,13 @@
 class ServerThread(threading.Thread):
   """A thread for serving HTTP requests."""
 
-  def __init__(self, ota_filename, serving_range):
+  def __init__(self, ota_filename, serving_range, speed_limit=None):
     threading.Thread.__init__(self)
     # serving_payload and serving_range are class attributes and the
     # UpdateHandler class is instantiated with every request.
     UpdateHandler.serving_payload = ota_filename
     UpdateHandler.serving_range = serving_range
+    UpdateHandler.speed_limit = speed_limit  # bytes/s; None or 0 = unthrottled
     self._httpd = BaseHTTPServer.HTTPServer(('127.0.0.1', 0), UpdateHandler)
     self.port = self._httpd.server_port
 
@@ -312,8 +328,8 @@
     self._httpd.socket.close()
 
 
-def StartServer(ota_filename, serving_range):
-  t = ServerThread(ota_filename, serving_range)
+def StartServer(ota_filename, serving_range, speed_limit=None):
+  t = ServerThread(ota_filename, serving_range, speed_limit)  # None: unthrottled
   t.start()
   return t
 
@@ -408,6 +424,27 @@
       ]) == 0
 
 
+def ParseSpeedLimit(arg: str) -> int:
+  """Converts a speed limit such as 10K, 5m or 1.5G to bytes per second.
+
+  A bare number is taken as bytes per second; K/M/G/T are binary (1024) units.
+  Raises argparse.ArgumentTypeError if the value cannot be parsed.
+  """
+  match = re.fullmatch(r"(\d+(?:\.\d+)?)([KMGT]?)", arg.strip().upper())
+  if not match:
+    raise argparse.ArgumentTypeError(
+        f"Wrong speed limit format {arg!r}, expected a number optionally followed by a unit (K, M, G, T), such as 10K, 5m, 3G (case insensitive)")
+  number, suffix = match.groups()
+  unit = {
+      "": 1,
+      "K": 1024,
+      "M": 1024 ** 2,
+      "G": 1024 ** 3,
+      "T": 1024 ** 4,
+  }[suffix]
+  return int(float(number) * unit)
+
+
 def main():
   parser = argparse.ArgumentParser(description='Android A/B OTA helper.')
   parser.add_argument('otafile', metavar='PAYLOAD', type=str,
@@ -444,7 +481,22 @@
                       help='Perform reset slot switch for this OTA package')
   parser.add_argument('--wipe-user-data', action='store_true',
                       help='Wipe userdata after installing OTA')
+  parser.add_argument('--vabc-none', action='store_true',
+                      help='Set Virtual AB Compression algorithm to none, but still use Android COW format')
+  parser.add_argument('--disable-vabc', action='store_true',
+                      help='Option to enable or disable vabc. If set to false, will fall back on A/B')
+  parser.add_argument('--enable-threading', action='store_true',
+                      help='Enable multi-threaded compression for VABC')
+  parser.add_argument('--batched-writes', action='store_true',
+                      help='Enable batched writes for VABC')
+  parser.add_argument('--speed-limit', type=str,
+                      help='Speed limit for serving payloads over HTTP. For '
+                      'example: 10K, 5m, 1G, input is case insensitive')
+
   args = parser.parse_args()
+  if args.speed_limit:
+    args.speed_limit = ParseSpeedLimit(args.speed_limit)
+
   logging.basicConfig(
       level=logging.WARNING if args.no_verbose else logging.INFO)
 
@@ -497,6 +549,14 @@
     args.extra_headers += "\nRUN_POST_INSTALL=0"
   if args.wipe_user_data:
     args.extra_headers += "\nPOWERWASH=1"
+  if args.vabc_none:
+    args.extra_headers += "\nVABC_NONE=1"
+  if args.disable_vabc:
+    args.extra_headers += "\nDISABLE_VABC=1"
+  if args.enable_threading:
+    args.extra_headers += "\nENABLE_THREADING=1"
+  if args.batched_writes:
+    args.extra_headers += "\nBATCHED_WRITES=1"
 
   with zipfile.ZipFile(args.otafile) as zfp:
     CARE_MAP_ENTRY_NAME = "care_map.pb"
@@ -531,7 +591,7 @@
       serving_range = (ota.offset, ota.size)
     else:
       serving_range = (0, os.stat(args.otafile).st_size)
-    server_thread = StartServer(args.otafile, serving_range)
+    server_thread = StartServer(args.otafile, serving_range, args.speed_limit)
     cmds.append(
         ['reverse', 'tcp:%d' % DEVICE_PORT, 'tcp:%d' % server_thread.port])
     finalize_cmds.append(['reverse', '--remove', 'tcp:%d' % DEVICE_PORT])