Add a speed limit option to simulate a slow network

During development, it is sometimes useful to understand
update_engine's performance under limited network bandwidth. Add a
--speed-limit parameter to update_device.py to simulate a slow network
connection. For example, --speed-limit 100K caps the download speed at
100 KiB/s. Accepted units (case insensitive, powers of 1024): K, M, G, T.

Test: th
Change-Id: I5947be1534866ec53d90696ce303062a644138ff
diff --git a/scripts/update_device.py b/scripts/update_device.py
index 950ff3d..7cf66a5 100755
--- a/scripts/update_device.py
+++ b/scripts/update_device.py
@@ -25,6 +25,7 @@
 import hashlib
 import logging
 import os
+import re
 import socket
 import subprocess
 import sys
@@ -50,7 +51,7 @@
 DEVICE_PORT = 1234
 
 
-def CopyFileObjLength(fsrc, fdst, buffer_size=128 * 1024, copy_length=None):
+def CopyFileObjLength(fsrc, fdst, buffer_size=128 * 1024, copy_length=None, speed_limit=None):
   """Copy from a file object to another.
 
   This function is similar to shutil.copyfileobj except that it allows to copy
@@ -61,10 +62,18 @@
     fdst: destination file object where to write to.
     buffer_size: size of the copy buffer in memory.
     copy_length: maximum number of bytes to copy, or None to copy everything.
+    speed_limit: upper limit for copying speed, in bytes per second.
 
   Returns:
     the number of bytes copied.
   """
+  # A buffer much larger than the limit makes traffic very spiky; keep it at
+  # least 1 byte, since read(0) returns b"" and would end the copy at once.
+  if speed_limit:
+    logging.info("Applying speed limit: %d bytes/s", speed_limit)
+    buffer_size = max(1, min(speed_limit // 32, buffer_size))
+
+  start_time = time.monotonic()
   copied = 0
   while True:
     chunk_size = buffer_size
@@ -75,6 +84,11 @@
     buf = fsrc.read(chunk_size)
     if not buf:
       break
+    if speed_limit:
+      # Sleep until the bytes copied so far fit within the allowed rate.
+      delay = copied / speed_limit - (time.monotonic() - start_time)
+      if delay > 0:
+        time.sleep(delay)
     fdst.write(buf)
     copied += len(buf)
   return copied
@@ -211,7 +225,8 @@
     self.end_headers()
 
     f.seek(serving_start + start_range)
-    CopyFileObjLength(f, self.wfile, copy_length=end_range - start_range)
+    CopyFileObjLength(f, self.wfile, copy_length=end_range - start_range,
+                      speed_limit=self.speed_limit)
 
   def do_POST(self):  # pylint: disable=invalid-name
     """Reply with the omaha response xml."""
@@ -291,12 +306,13 @@
 class ServerThread(threading.Thread):
   """A thread for serving HTTP requests."""
 
-  def __init__(self, ota_filename, serving_range):
+  def __init__(self, ota_filename, serving_range, speed_limit=None):
     threading.Thread.__init__(self)
     # serving_payload and serving_range are class attributes and the
     # UpdateHandler class is instantiated with every request.
     UpdateHandler.serving_payload = ota_filename
     UpdateHandler.serving_range = serving_range
+    UpdateHandler.speed_limit = speed_limit
     self._httpd = BaseHTTPServer.HTTPServer(('127.0.0.1', 0), UpdateHandler)
     self.port = self._httpd.server_port
 
@@ -312,8 +328,8 @@
     self._httpd.socket.close()
 
 
-def StartServer(ota_filename, serving_range):
-  t = ServerThread(ota_filename, serving_range)
+def StartServer(ota_filename, serving_range, speed_limit=None):
+  t = ServerThread(ota_filename, serving_range, speed_limit)
   t.start()
   return t
 
@@ -408,6 +424,27 @@
       ]) == 0
 
 
+def ParseSpeedLimit(arg: str) -> int:
+  """Parse a speed limit such as "100", "10K" or "1.5m" into bytes/second.
+
+  An optional K, M, G or T suffix (case insensitive) multiplies the number by
+  1024, 1024**2, 1024**3 or 1024**4. Raises argparse.ArgumentTypeError on bad
+  input, so this can be passed directly as an argparse type= converter.
+  """
+  match = re.fullmatch(r"(\d+(?:\.\d+)?)([KMGT]?)", arg.strip().upper())
+  if not match:
+    raise argparse.ArgumentTypeError(
+        "Wrong speed limit format, expected a number optionally followed by "
+        "a unit, such as 10K, 5m, 3G (case insensitive)")
+  number, suffix = match.groups()
+  # No suffix means plain bytes; each unit letter is a further factor of 1024.
+  exponent = " KMGT".index(suffix or " ")
+  speed_limit = int(float(number) * 1024 ** exponent)
+  if speed_limit <= 0:
+    raise argparse.ArgumentTypeError("Speed limit must be at least 1 byte/s")
+  return speed_limit
+
+
 def main():
   parser = argparse.ArgumentParser(description='Android A/B OTA helper.')
   parser.add_argument('otafile', metavar='PAYLOAD', type=str,
@@ -446,7 +483,14 @@
                       help='Wipe userdata after installing OTA')
   parser.add_argument('--disable-vabc', action='store_true',
                       help='Disable vabc during OTA')
+  parser.add_argument('--speed-limit', type=ParseSpeedLimit, default=None,
+                      help='Upper bound on the speed of serving payloads '
+                      'over HTTP, in bytes per second, to simulate a slow '
+                      'network. A K, M, G or T suffix (powers of 1024, case '
+                      'insensitive) may be appended, e.g. 10K, 5m, 1G. '
+                      'Default: unlimited.')
+
   args = parser.parse_args()
   logging.basicConfig(
       level=logging.WARNING if args.no_verbose else logging.INFO)
 
@@ -535,7 +579,7 @@
       serving_range = (ota.offset, ota.size)
     else:
       serving_range = (0, os.stat(args.otafile).st_size)
-    server_thread = StartServer(args.otafile, serving_range)
+    server_thread = StartServer(args.otafile, serving_range, args.speed_limit)
     cmds.append(
         ['reverse', 'tcp:%d' % DEVICE_PORT, 'tcp:%d' % server_thread.port])
     finalize_cmds.append(['reverse', '--remove', 'tcp:%d' % DEVICE_PORT])