fastboot: Introduce connect / disconnect for network-connected devices

Use the newly introduced FileLock and network-connected device storage
entities to implement the "fastboot connect" / "fastboot disconnect"
commands.

Test: verified on Windows and Linux that everything works as discussed in go/fastboot-connect-disconnect
Bug: 267507450
Bug: 267506875
Change-Id: I2d6495ad567a3ddadd471a89b82d78c8c36a3d52
Signed-off-by: Dmitrii Merkurev <dimorinny@google.com>
diff --git a/fastboot/fastboot.cpp b/fastboot/fastboot.cpp
index e739404..0c8747c 100644
--- a/fastboot/fastboot.cpp
+++ b/fastboot/fastboot.cpp
@@ -73,6 +73,7 @@
 #include "diagnose_usb.h"
 #include "fastboot_driver.h"
 #include "fs.h"
+#include "storage.h"
 #include "super_flash_helper.h"
 #include "task.h"
 #include "tcp.h"
@@ -275,8 +276,36 @@
     return 0;
 }
 
-static int match_fastboot(usb_ifc_info* info) {
-    return match_fastboot_with_serial(info, serial);
+// Builds a USB matcher bound to |local_serial|; the pointer itself is captured, so the string
+// must outlive the matcher. open_device() passes nullptr when any fastboot device will do.
+static ifc_match_func match_fastboot(const char* local_serial = serial) {
+    return [=](usb_ifc_info* info) { return match_fastboot_with_serial(info, local_serial); };
+}
+
+// Prints a single device line in a format compatible with "adb devices":
+//   "<serial>\t<status>" by default, or
+//   "<serial, padded to 22 columns> <status> <details>" in long listing mode.
+// Nothing is printed for a null or empty serial; empty status/details are omitted.
+static void PrintDevice(const char* local_serial, const char* status = nullptr,
+                        const char* details = nullptr) {
+    if (local_serial == nullptr || strlen(local_serial) == 0) {
+        return;
+    }
+
+    const bool has_status = status != nullptr && strlen(status) > 0;
+    if (g_long_listing) {
+        printf("%-22s", local_serial);
+        if (has_status) printf(" %s", status);
+        // Details (e.g. the USB device path) are only shown in long listing mode.
+        if (details != nullptr && strlen(details) > 0) printf(" %s", details);
+    } else {
+        // No space after the tab: keeps the previous "%s\t%s" output so that
+        // tab-separated parsers of "fastboot devices" keep working.
+        printf("%s\t", local_serial);
+        if (has_status) printf("%s", status);
+    }
+
+    putchar('\n');
 }
 
 static int list_devices_callback(usb_ifc_info* info) {
@@ -292,88 +321,230 @@
         if (!serial[0]) {
             serial = "????????????";
         }
-        // output compatible with "adb devices"
-        if (!g_long_listing) {
-            printf("%s\t%s", serial.c_str(), interface.c_str());
-        } else {
-            printf("%-22s %s", serial.c_str(), interface.c_str());
-            if (strlen(info->device_path) > 0) printf(" %s", info->device_path);
-        }
-        putchar('\n');
+
+        PrintDevice(serial.c_str(), interface.c_str(), info->device_path);
     }
 
     return -1;
 }
 
-// Opens a new Transport connected to a device. If |serial| is non-null it will be used to identify
-// a specific device, otherwise the first USB device found will be used.
+struct NetworkSerial {  // A parsed network device serial, e.g. "tcp:localhost:5554".
+    Socket::Protocol protocol;  // kTcp for "tcp:" serials, kUdp for "udp:" serials.
+    std::string address;        // Host part of the serial.
+    int port;                   // Port given in the serial, otherwise a default port.
+};
+
+// Parses "<tcp|udp>:<address>[:<port>]"; a missing port selects the protocol's default port.
+// Fails for anything else (e.g. USB serials, or malformed/out-of-range ports).
+static Result<NetworkSerial> ParseNetworkSerial(const std::string& serial) {
+    Socket::Protocol protocol;
+    int port;
+    std::string net_address;
+    if (android::base::StartsWith(serial, "tcp:")) {
+        protocol = Socket::Protocol::kTcp;
+        port = tcp::kDefaultPort;
+        net_address = serial.substr(strlen("tcp:"));
+    } else if (android::base::StartsWith(serial, "udp:")) {
+        protocol = Socket::Protocol::kUdp;
+        port = udp::kDefaultPort;
+        net_address = serial.substr(strlen("udp:"));
+    } else {
+        return Error() << "invalid network address: " << serial << ". Expected format:\n"
+                       << "<protocol>:<address>:<port> (tcp:localhost:5554)";
+    }
+    // Unlike splitting on ':', ParseNetAddress range-checks the port and handles "[ipv6]:port".
+    std::string host, error;
+    if (!android::base::ParseNetAddress(net_address, &host, &port, nullptr, &error)) {
+        return Error() << "invalid network address '" << net_address << "': " << error;
+    }
+    return NetworkSerial{protocol, host, port};
+}
+
+// Opens a new Transport connected to a particular device.
+// Arguments:
//
-// If |serial| is non-null but invalid, this exits.
-// Otherwise it blocks until the target is available.
+// local_serial - non-null serial; opened over the network if ParseNetworkSerial() accepts it,
+//                otherwise treated as a USB serial
+// wait_for_device - if false, return nullptr instead of waiting when the device is unavailable
+// announce - log network connection errors and the "< waiting for ... >" message (LOG(ERROR))
//
// The returned Transport is a singleton, so multiple calls to this function will return the same
// object, and the caller should not attempt to delete the returned Transport.
-static Transport* open_device() {
-    bool announce = true;
-
-    Socket::Protocol protocol = Socket::Protocol::kTcp;
-    std::string host;
-    int port = 0;
-    if (serial != nullptr) {
-        const char* net_address = nullptr;
-
-        if (android::base::StartsWith(serial, "tcp:")) {
-            protocol = Socket::Protocol::kTcp;
-            port = tcp::kDefaultPort;
-            net_address = serial + strlen("tcp:");
-        } else if (android::base::StartsWith(serial, "udp:")) {
-            protocol = Socket::Protocol::kUdp;
-            port = udp::kDefaultPort;
-            net_address = serial + strlen("udp:");
-        }
-
-        if (net_address != nullptr) {
-            std::string error;
-            if (!android::base::ParseNetAddress(net_address, &host, &port, nullptr, &error)) {
-                die("invalid network address '%s': %s\n", net_address, error.c_str());
-            }
-        }
-    }
+static Transport* open_device(const char* local_serial, bool wait_for_device = true,
+                              bool announce = true) {
+    const Result<NetworkSerial> network_serial = ParseNetworkSerial(local_serial);
 
     Transport* transport = nullptr;
     while (true) {
-        if (!host.empty()) {
+        if (network_serial.ok()) {
             std::string error;
-            if (protocol == Socket::Protocol::kTcp) {
-                transport = tcp::Connect(host, port, &error).release();
-            } else if (protocol == Socket::Protocol::kUdp) {
-                transport = udp::Connect(host, port, &error).release();
+            if (network_serial->protocol == Socket::Protocol::kTcp) {
+                transport = tcp::Connect(network_serial->address, network_serial->port, &error)
+                                    .release();
+            } else if (network_serial->protocol == Socket::Protocol::kUdp) {
+                transport = udp::Connect(network_serial->address, network_serial->port, &error)
+                                    .release();
             }
 
             if (transport == nullptr && announce) {
-                fprintf(stderr, "error: %s\n", error.c_str());
+                LOG(ERROR) << "error: " << error;
             }
         } else {
-            transport = usb_open(match_fastboot);
+            transport = usb_open(match_fastboot(local_serial));
         }
 
         if (transport != nullptr) {
             return transport;
         }
 
+        if (!wait_for_device) {
+            return nullptr;
+        }
+
         if (announce) {
             announce = false;
-            fprintf(stderr, "< waiting for %s >\n", serial ? serial : "any device");
+            LOG(ERROR) << "< waiting for " << local_serial << " >";  // Same format as before.
         }
         std::this_thread::sleep_for(std::chrono::milliseconds(1));
     }
 }
 
+// Opens every network device saved by "fastboot connect" and returns a transport to the
+// last one (in sorted serial order) that could be opened, or nullptr if none could.
+// When |print| is true, also prints each saved device as "device" or "offline".
+static Transport* NetworkDeviceConnected(bool print = false) {
+    ConnectedDevicesStorage storage;
+    std::set<std::string> devices;
+    {
+        FileLock lock = storage.Lock();
+        devices = storage.ReadDevices(lock);
+    }
+
+    Transport* result = nullptr;
+    for (const std::string& device : devices) {
+        Transport* transport = open_device(device.c_str(), false, false);
+        if (print) {
+            PrintDevice(device.c_str(), transport == nullptr ? "offline" : "device");
+        }
+        if (transport != nullptr) {
+            delete result;  // Only one transport is returned; don't leak earlier connections.
+            result = transport;
+        }
+    }
+
+    return result;
+}
+
+// Detects a connected fastboot device and opens a new Transport to it.
+// Detection logic:
+//
+// If a serial is provided, connect to that particular USB/network device.
+// Otherwise:
+// 1. Check connected USB devices and use the first matching one usb_open() finds.
+// 2. Otherwise, use the last reachable network device saved by "fastboot connect".
+// 3. If nothing is available, wait for any device by repeating steps 1 and 2.
+//
+// The returned Transport is a singleton, so multiple calls to this function will return the same
+// object, and the caller should not attempt to delete the returned Transport.
+static Transport* open_device() {
+    if (serial != nullptr) {
+        return open_device(serial);
+    }
+
+    bool announce = true;
+    Transport* transport = nullptr;
+    while (true) {
+        transport = usb_open(match_fastboot(nullptr));
+        if (transport != nullptr) {
+            return transport;
+        }
+
+        transport = NetworkDeviceConnected();
+        if (transport != nullptr) {
+            return transport;
+        }
+
+        if (announce) {
+            announce = false;
+            LOG(ERROR) << "< waiting for any device >";
+        }
+        std::this_thread::sleep_for(std::chrono::milliseconds(1));
+    }
+}
+
+// Implements "fastboot connect <protocol>:<address>[:<port>]": checks that the device can
+// be opened and, if so, records its serial in the connected devices storage.
+// Returns 0 on success and 1 if the device could not be opened.
+static int Connect(int argc, char* argv[]) {
+    if (argc != 1) {
+        LOG(FATAL) << "connect command requires to receive only 1 argument. Usage:" << std::endl
+                   << "fastboot connect [tcp:|udp:host:port]";
+    }
+
+    const char* const local_serial = argv[0];
+    // Validate the serial format before trying to connect.
+    EXPECT(ParseNetworkSerial(local_serial));
+
+    if (open_device(local_serial, false) == nullptr) {
+        return 1;
+    }
+
+    ConnectedDevicesStorage storage;
+    FileLock lock = storage.Lock();
+    std::set<std::string> devices = storage.ReadDevices(lock);
+    devices.insert(local_serial);
+    storage.WriteDevices(lock, devices);
+    return 0;
+}
+
+// Forgets a single network device: validates |local_serial| as a network serial, then removes
+// it from the connected devices storage (an unknown serial leaves the stored set unchanged).
+static int Disconnect(const char* local_serial) {
+    EXPECT(ParseNetworkSerial(local_serial));
+
+    ConnectedDevicesStorage storage;
+    FileLock lock = storage.Lock();
+    std::set<std::string> devices = storage.ReadDevices(lock);
+    devices.erase(local_serial);
+    storage.WriteDevices(lock, devices);
+
+    return 0;
+}
+
+// Implements "fastboot disconnect" with no arguments: forgets every network device
+// recorded in the connected devices storage. Always returns 0.
+static int Disconnect() {
+    ConnectedDevicesStorage storage;
+    // The storage is cleared while holding its FileLock.
+    FileLock lock = storage.Lock();
+    storage.Clear(lock);
+    return 0;
+}
+
+// Dispatches "fastboot disconnect [serial]": with no argument every stored network device is
+// forgotten, with one argument only that device is. Any other argument count is LOG(FATAL).
+static int Disconnect(int argc, char* argv[]) {
+    if (argc == 0) {
+        return Disconnect();
+    }
+    if (argc == 1) {
+        return Disconnect(argv[0]);
+    }
+
+    LOG(FATAL) << "disconnect command can receive only 0 or 1 arguments. Usage:"
+               << std::endl
+               << "fastboot disconnect # disconnect all devices" << std::endl
+               << "fastboot disconnect [tcp:|udp:host:port] # disconnect device";
+
+    return 0;
+}
+
 static void list_devices() {
     // We don't actually open a USB device here,
     // just getting our callback called so we can
     // list all the connected devices.
     usb_open(list_devices_callback);
+    delete NetworkDeviceConnected(/* print */ true);  // Only the printed list is needed.
 }
 
 static void syntax_error(const char* fmt, ...) {
@@ -1943,10 +2114,19 @@
     }
 }
 
-static void FastbootLogger(android::base::LogId /* id */, android::base::LogSeverity /* severity */,
+static void FastbootLogger(android::base::LogId /* id */, android::base::LogSeverity severity,
                            const char* /* tag */, const char* /* file */, unsigned int /* line */,
                            const char* message) {
-    verbose("%s", message);
+    // INFO is printed to stdout and ERROR to stderr unconditionally; every other severity
+    // goes through verbose(). verbose() keeps the pre-change "%s" format: the previous call
+    // passed no newline, so verbose() is expected to terminate each line itself.
+    switch (severity) {
+        case android::base::INFO: fprintf(stdout, "%s\n", message); break;
+        case android::base::ERROR: fprintf(stderr, "%s\n", message); break;
+        default:
+            verbose("%s", message);
+            break;
+    }
 }
 
 static void FastbootAborter(const char* message) {
@@ -2099,6 +2279,18 @@
         return 0;
     }
 
+    if (argc > 0 && !strcmp(*argv, "connect")) {
+        // *argv is the command word (argv was already advanced past the options), so skip
+        // only that word. Subtracting optind again would overshoot whenever options precede
+        // the command, e.g. "fastboot -v connect tcp:host".
+        return Connect(argc - 1, argv + 1);
+    }
+
+    if (argc > 0 && !strcmp(*argv, "disconnect")) {
+        // As for "connect", skip only the command word itself.
+        return Disconnect(argc - 1, argv + 1);
+    }
+
     if (argc > 0 && !strcmp(*argv, "help")) {
         return show_help();
     }