fastboot: Introduce connect / disconnect for network-connected devices

Use introduced FileLock and network-connected devices storage entities
to introduce fastboot connect / disconnect commands

Test: everything works like discussed here go/fastboot-connect-disconnect on windows/linux
Bug: 267507450
Bug: 267506875
Change-Id: I2d6495ad567a3ddadd471a89b82d78c8c36a3d52
Signed-off-by: Dmitrii Merkurev <dimorinny@google.com>
diff --git a/fastboot/fastboot.cpp b/fastboot/fastboot.cpp
index e739404..0c8747c 100644
--- a/fastboot/fastboot.cpp
+++ b/fastboot/fastboot.cpp
@@ -73,6 +73,7 @@
 #include "diagnose_usb.h"
 #include "fastboot_driver.h"
 #include "fs.h"
+#include "storage.h"
 #include "super_flash_helper.h"
 #include "task.h"
 #include "tcp.h"
@@ -275,8 +276,36 @@
     return 0;
 }
 
-static int match_fastboot(usb_ifc_info* info) {
-    return match_fastboot_with_serial(info, serial);
+static ifc_match_func match_fastboot(const char* local_serial = serial) {
+    return [local_serial](usb_ifc_info* info) -> int {
+        return match_fastboot_with_serial(info, local_serial);
+    };
+}
+
+// output compatible with "adb devices"
+static void PrintDevice(const char* local_serial, const char* status = nullptr,
+                        const char* details = nullptr) {
+    if (local_serial == nullptr || strlen(local_serial) == 0) {
+        return;
+    }
+
+    if (g_long_listing) {
+        printf("%-22s", local_serial);
+    } else {
+        printf("%s\t", local_serial);
+    }
+
+    if (status != nullptr && strlen(status) > 0) {
+        printf(" %s", status);
+    }
+
+    if (g_long_listing) {
+        if (details != nullptr && strlen(details) > 0) {
+            printf(" %s", details);
+        }
+    }
+
+    putchar('\n');
 }
 
 static int list_devices_callback(usb_ifc_info* info) {
@@ -292,88 +321,230 @@
         if (!serial[0]) {
             serial = "????????????";
         }
-        // output compatible with "adb devices"
-        if (!g_long_listing) {
-            printf("%s\t%s", serial.c_str(), interface.c_str());
-        } else {
-            printf("%-22s %s", serial.c_str(), interface.c_str());
-            if (strlen(info->device_path) > 0) printf(" %s", info->device_path);
-        }
-        putchar('\n');
+
+        PrintDevice(serial.c_str(), interface.c_str(), info->device_path);
     }
 
     return -1;
 }
 
-// Opens a new Transport connected to a device. If |serial| is non-null it will be used to identify
-// a specific device, otherwise the first USB device found will be used.
+struct NetworkSerial {
+    Socket::Protocol protocol;
+    std::string address;
+    int port;
+};
+
+static Result<NetworkSerial> ParseNetworkSerial(const std::string& serial) {
+    const auto serial_parsed = android::base::Tokenize(serial, ":");
+    const auto parsed_segments_count = serial_parsed.size();
+    if (parsed_segments_count != 2 && parsed_segments_count != 3) {
+        return Error() << "invalid network address: " << serial << ". Expected format:\n"
+                       << "<protocol>:<address>:<port> (tcp:localhost:5554)";
+    }
+
+    Socket::Protocol protocol;
+    if (serial_parsed[0] == "tcp") {
+        protocol = Socket::Protocol::kTcp;
+    } else if (serial_parsed[0] == "udp") {
+        protocol = Socket::Protocol::kUdp;
+    } else {
+        return Error() << "invalid network address: " << serial << ". Expected format:\n"
+                       << "<protocol>:<address>:<port> (tcp:localhost:5554)";
+    }
+
+    int port = 5554;
+    if (parsed_segments_count == 3) {
+        android::base::ParseInt(serial_parsed[2], &port, 5554);
+    }
+
+    return NetworkSerial{protocol, serial_parsed[1], port};
+}
+
+// Opens a new Transport connected to the particular device.
+// arguments:
 //
-// If |serial| is non-null but invalid, this exits.
-// Otherwise it blocks until the target is available.
+// local_serial - device to connect (can be a network or usb serial name)
+// wait_for_device - flag indicates whether we need to wait for device
+// announce - flag indicates whether we need to print error to stdout in case
+// we cannot connect to the device
 //
 // The returned Transport is a singleton, so multiple calls to this function will return the same
 // object, and the caller should not attempt to delete the returned Transport.
-static Transport* open_device() {
-    bool announce = true;
-
-    Socket::Protocol protocol = Socket::Protocol::kTcp;
-    std::string host;
-    int port = 0;
-    if (serial != nullptr) {
-        const char* net_address = nullptr;
-
-        if (android::base::StartsWith(serial, "tcp:")) {
-            protocol = Socket::Protocol::kTcp;
-            port = tcp::kDefaultPort;
-            net_address = serial + strlen("tcp:");
-        } else if (android::base::StartsWith(serial, "udp:")) {
-            protocol = Socket::Protocol::kUdp;
-            port = udp::kDefaultPort;
-            net_address = serial + strlen("udp:");
-        }
-
-        if (net_address != nullptr) {
-            std::string error;
-            if (!android::base::ParseNetAddress(net_address, &host, &port, nullptr, &error)) {
-                die("invalid network address '%s': %s\n", net_address, error.c_str());
-            }
-        }
-    }
+static Transport* open_device(const char* local_serial, bool wait_for_device = true,
+                              bool announce = true) {
+    const Result<NetworkSerial> network_serial = ParseNetworkSerial(local_serial);
 
     Transport* transport = nullptr;
     while (true) {
-        if (!host.empty()) {
+        if (network_serial.ok()) {
             std::string error;
-            if (protocol == Socket::Protocol::kTcp) {
-                transport = tcp::Connect(host, port, &error).release();
-            } else if (protocol == Socket::Protocol::kUdp) {
-                transport = udp::Connect(host, port, &error).release();
+            if (network_serial->protocol == Socket::Protocol::kTcp) {
+                transport = tcp::Connect(network_serial->address, network_serial->port, &error)
+                                    .release();
+            } else if (network_serial->protocol == Socket::Protocol::kUdp) {
+                transport = udp::Connect(network_serial->address, network_serial->port, &error)
+                                    .release();
             }
 
             if (transport == nullptr && announce) {
-                fprintf(stderr, "error: %s\n", error.c_str());
+                LOG(ERROR) << "error: " << error;
             }
         } else {
-            transport = usb_open(match_fastboot);
+            transport = usb_open(match_fastboot(local_serial));
         }
 
         if (transport != nullptr) {
             return transport;
         }
 
+        if (!wait_for_device) {
+            return nullptr;
+        }
+
         if (announce) {
             announce = false;
-            fprintf(stderr, "< waiting for %s >\n", serial ? serial : "any device");
+            LOG(ERROR) << "< waiting for " << local_serial << ">";
         }
         std::this_thread::sleep_for(std::chrono::milliseconds(1));
     }
 }
 
+static Transport* NetworkDeviceConnected(bool print = false) {
+    Transport* transport = nullptr;
+    Transport* result = nullptr;
+
+    ConnectedDevicesStorage storage;
+    std::set<std::string> devices;
+    {
+        FileLock lock = storage.Lock();
+        devices = storage.ReadDevices(lock);
+    }
+
+    for (const std::string& device : devices) {
+        transport = open_device(device.c_str(), false, false);
+
+        if (print) {
+            PrintDevice(device.c_str(), transport == nullptr ? "offline" : "device");
+        }
+
+        if (transport != nullptr) {
+            result = transport;
+        }
+    }
+
+    return result;
+}
+
+// Detects the fastboot connected device to open a new Transport.
+// Detecting logic:
+//
+// if serial is provided - try to connect to this particular usb/network device
+// othervise:
+// 1. Check connected usb devices and return the last connected one
+// 2. Check connected network devices and return the last connected one
+// 2. If nothing is connected - wait for any device by repeating p. 1 and 2
+//
+// The returned Transport is a singleton, so multiple calls to this function will return the same
+// object, and the caller should not attempt to delete the returned Transport.
+static Transport* open_device() {
+    if (serial != nullptr) {
+        return open_device(serial);
+    }
+
+    bool announce = true;
+    Transport* transport = nullptr;
+    while (true) {
+        transport = usb_open(match_fastboot(nullptr));
+        if (transport != nullptr) {
+            return transport;
+        }
+
+        transport = NetworkDeviceConnected();
+        if (transport != nullptr) {
+            return transport;
+        }
+
+        if (announce) {
+            announce = false;
+            LOG(ERROR) << "< waiting for any device >";
+        }
+        std::this_thread::sleep_for(std::chrono::milliseconds(1));
+    }
+}
+
+static int Connect(int argc, char* argv[]) {
+    if (argc != 1) {
+        LOG(FATAL) << "connect command requires to receive only 1 argument. Usage:" << std::endl
+                   << "fastboot connect [tcp:|udp:host:port]";
+    }
+
+    const char* local_serial = *argv;
+    EXPECT(ParseNetworkSerial(local_serial));
+
+    const Transport* transport = open_device(local_serial, false);
+    if (transport == nullptr) {
+        return 1;
+    }
+
+    ConnectedDevicesStorage storage;
+    {
+        FileLock lock = storage.Lock();
+        std::set<std::string> devices = storage.ReadDevices(lock);
+        devices.insert(local_serial);
+        storage.WriteDevices(lock, devices);
+    }
+
+    return 0;
+}
+
+static int Disconnect(const char* local_serial) {
+    EXPECT(ParseNetworkSerial(local_serial));
+
+    ConnectedDevicesStorage storage;
+    {
+        FileLock lock = storage.Lock();
+        std::set<std::string> devices = storage.ReadDevices(lock);
+        devices.erase(local_serial);
+        storage.WriteDevices(lock, devices);
+    }
+
+    return 0;
+}
+
+static int Disconnect() {
+    ConnectedDevicesStorage storage;
+    {
+        FileLock lock = storage.Lock();
+        storage.Clear(lock);
+    }
+
+    return 0;
+}
+
+static int Disconnect(int argc, char* argv[]) {
+    switch (argc) {
+        case 0: {
+            return Disconnect();
+        }
+        case 1: {
+            return Disconnect(*argv);
+        }
+        default:
+            LOG(FATAL) << "disconnect command can receive only 0 or 1 arguments. Usage:"
+                       << std::endl
+                       << "fastboot disconnect # disconnect all devices" << std::endl
+                       << "fastboot disconnect [tcp:|udp:host:port] # disconnect device";
+    }
+
+    return 0;
+}
+
 static void list_devices() {
     // We don't actually open a USB device here,
     // just getting our callback called so we can
     // list all the connected devices.
     usb_open(list_devices_callback);
+    NetworkDeviceConnected(/* print */ true);
 }
 
 static void syntax_error(const char* fmt, ...) {
@@ -1943,10 +2114,19 @@
     }
 }
 
-static void FastbootLogger(android::base::LogId /* id */, android::base::LogSeverity /* severity */,
+static void FastbootLogger(android::base::LogId /* id */, android::base::LogSeverity severity,
                            const char* /* tag */, const char* /* file */, unsigned int /* line */,
                            const char* message) {
-    verbose("%s", message);
+    switch (severity) {
+        case android::base::INFO:
+            fprintf(stdout, "%s\n", message);
+            break;
+        case android::base::ERROR:
+            fprintf(stderr, "%s\n", message);
+            break;
+        default:
+            verbose("%s\n", message);
+    }
 }
 
 static void FastbootAborter(const char* message) {
@@ -2099,6 +2279,18 @@
         return 0;
     }
 
+    if (argc > 0 && !strcmp(*argv, "connect")) {
+        argc -= optind;
+        argv += optind;
+        return Connect(argc, argv);
+    }
+
+    if (argc > 0 && !strcmp(*argv, "disconnect")) {
+        argc -= optind;
+        argv += optind;
+        return Disconnect(argc, argv);
+    }
+
     if (argc > 0 && !strcmp(*argv, "help")) {
         return show_help();
     }
diff --git a/fastboot/filesystem.cpp b/fastboot/filesystem.cpp
index a58ba00..94fde8e 100644
--- a/fastboot/filesystem.cpp
+++ b/fastboot/filesystem.cpp
@@ -38,15 +38,15 @@
 #ifdef _WIN32
     HANDLE handle = reinterpret_cast<HANDLE>(_get_osfhandle(fd));
     OVERLAPPED overlapped = {};
-    const BOOL locked = LockFileEx(handle, LOCKFILE_EXCLUSIVE_LOCK, 0,
-                                   MAXDWORD, MAXDWORD, &overlapped);
+    const BOOL locked =
+            LockFileEx(handle, LOCKFILE_EXCLUSIVE_LOCK, 0, MAXDWORD, MAXDWORD, &overlapped);
     return locked ? 0 : -1;
 #else
     return flock(fd, LOCK_EX);
 #endif
 }
 
-}
+}  // namespace
 
 // inspired by adb implementation:
 // cs.android.com/android/platform/superproject/+/master:packages/modules/adb/adb_utils.cpp;l=275
@@ -90,9 +90,9 @@
 bool EnsureDirectoryExists(const std::string& directory_path) {
     const int result =
 #ifdef _WIN32
-                       _mkdir(directory_path.c_str());
+            _mkdir(directory_path.c_str());
 #else
-                       mkdir(directory_path.c_str(), S_IRWXU | S_IRWXG | S_IROTH | S_IXOTH);
+            mkdir(directory_path.c_str(), S_IRWXU | S_IRWXG | S_IROTH | S_IXOTH);
 #endif
 
     return result == 0 || errno == EEXIST;
diff --git a/fastboot/filesystem.h b/fastboot/filesystem.h
index 3261496..5f41fbc 100644
--- a/fastboot/filesystem.h
+++ b/fastboot/filesystem.h
@@ -25,9 +25,9 @@
 // TODO(b/175635923): remove after enabling libc++fs for windows
 const char kPathSeparator =
 #ifdef _WIN32
-                            '\\';
+        '\\';
 #else
-                            '/';
+        '/';
 #endif
 
 std::string GetHomeDirPath();
diff --git a/fastboot/storage.cpp b/fastboot/storage.cpp
index db13dd6..d6e00cf 100644
--- a/fastboot/storage.cpp
+++ b/fastboot/storage.cpp
@@ -43,20 +43,20 @@
     devices_lock_path_ = home_fastboot_path + kPathSeparator + "devices.lock";
 }
 
-void ConnectedDevicesStorage::WriteDevices(const std::set<std::string>& devices) {
+void ConnectedDevicesStorage::WriteDevices(const FileLock&, const std::set<std::string>& devices) {
     std::ofstream devices_stream(devices_path_);
     std::copy(devices.begin(), devices.end(),
               std::ostream_iterator<std::string>(devices_stream, "\n"));
 }
 
-std::set<std::string> ConnectedDevicesStorage::ReadDevices() {
+std::set<std::string> ConnectedDevicesStorage::ReadDevices(const FileLock&) {
     std::ifstream devices_stream(devices_path_);
     std::istream_iterator<std::string> start(devices_stream), end;
     std::set<std::string> devices(start, end);
     return devices;
 }
 
-void ConnectedDevicesStorage::Clear() {
+void ConnectedDevicesStorage::Clear(const FileLock&) {
     if (!android::base::RemoveFileIfExists(devices_path_)) {
         LOG(FATAL) << "Failed to clear connected device list: " << devices_path_;
     }
diff --git a/fastboot/storage.h b/fastboot/storage.h
index 1cce950..0cc3d86 100644
--- a/fastboot/storage.h
+++ b/fastboot/storage.h
@@ -24,11 +24,12 @@
 class ConnectedDevicesStorage {
   public:
     ConnectedDevicesStorage();
-    void WriteDevices(const std::set<std::string>& devices);
-    std::set<std::string> ReadDevices();
-    void Clear();
+    void WriteDevices(const FileLock&, const std::set<std::string>& devices);
+    std::set<std::string> ReadDevices(const FileLock&);
+    void Clear(const FileLock&);
 
     FileLock Lock() const;
+
   private:
     std::string devices_path_;
     std::string devices_lock_path_;
diff --git a/fastboot/usb.h b/fastboot/usb.h
index e5f56e2..69581ab 100644
--- a/fastboot/usb.h
+++ b/fastboot/usb.h
@@ -28,6 +28,8 @@
 
 #pragma once
 
+#include <functional>
+
 #include "transport.h"
 
 struct usb_ifc_info {
@@ -61,7 +63,7 @@
     virtual int Reset() = 0;
 };
 
-typedef int (*ifc_match_func)(usb_ifc_info *ifc);
+typedef std::function<int(usb_ifc_info*)> ifc_match_func;
 
 // 0 is non blocking
 UsbTransport* usb_open(ifc_match_func callback, uint32_t timeout_ms = 0);
diff --git a/fastboot/util.h b/fastboot/util.h
index 290d0d5..bc01473 100644
--- a/fastboot/util.h
+++ b/fastboot/util.h
@@ -6,11 +6,21 @@
 #include <string>
 #include <vector>
 
+#include <android-base/logging.h>
+#include <android-base/result.h>
 #include <android-base/unique_fd.h>
 #include <bootimg.h>
 #include <liblp/liblp.h>
 #include <sparse/sparse.h>
 
+using android::base::ErrnoError;
+using android::base::Error;
+using android::base::Result;
+using android::base::ResultError;
+
+#define EXPECT(result) \
+    (result.ok() ? result.value() : (LOG(FATAL) << result.error().message(), result.value()))
+
 using SparsePtr = std::unique_ptr<sparse_file, decltype(&sparse_file_destroy)>;
 
 /* util stuff */