adb: support 'adb connect' by mDNS service instance name.

Bug: 152886765

Test: $ANDROID_HOST_OUT/nativetest64/adb_test/adb_test
Test: test_adb.py

Change-Id: I7e93ceca7cdf913060bbc5afe824593a9922c6d9
diff --git a/adb/adb_wifi.h b/adb/adb_wifi.h
index 3a6b0b1..5113c15 100644
--- a/adb/adb_wifi.h
+++ b/adb/adb_wifi.h
@@ -16,6 +16,7 @@
 
 #pragma once
 
+#include <optional>
 #include <string>
 
 #include "adb.h"
@@ -30,6 +31,18 @@
 std::string mdns_check();
 std::string mdns_list_discovered_services();
 
+struct MdnsInfo {  // One resolved mDNS service and the address/port it resolved to.
+    std::string service_name;  // mDNS instance name
+    std::string service_type;  // registration type
+    std::string addr;          // resolved IP address, as text
+    uint16_t port = 0;         // resolved port
+
+    MdnsInfo(std::string_view name, std::string_view type, std::string_view ip, uint16_t svc_port)
+        : service_name(name), service_type(type), addr(ip), port(svc_port) {}
+};
+// Finds a resolved non-pairing adb service by mDNS instance name; |name| must be non-empty.
+std::optional<MdnsInfo> mdns_get_connect_service_info(std::string_view name);
+
 #else  // !ADB_HOST
 
 struct AdbdAuthContext;
diff --git a/adb/client/transport_mdns.cpp b/adb/client/transport_mdns.cpp
index c9993b7..f8c9de3 100644
--- a/adb/client/transport_mdns.cpp
+++ b/adb/client/transport_mdns.cpp
@@ -38,6 +38,7 @@
 #include "adb_trace.h"
 #include "adb_utils.h"
 #include "adb_wifi.h"
+#include "client/mdns_utils.h"
 #include "fdevent/fdevent.h"
 #include "sysdeps.h"
 
@@ -319,7 +320,7 @@
 
     static void initAdbServiceRegistries();
 
-    static void forEachService(const ServiceRegistry& services, const std::string& hostname,
+    static void forEachService(const ServiceRegistry& services, std::string_view hostname,
                                adb_secure_foreach_service_callback cb);
 
     static bool connectByServiceName(const ServiceRegistry& services,
@@ -362,7 +363,7 @@
 
 // static
 void ResolvedService::forEachService(const ServiceRegistry& services,
-                                     const std::string& wanted_service_name,
+                                     std::string_view wanted_service_name,
                                      adb_secure_foreach_service_callback cb) {
     initAdbServiceRegistries();
 
@@ -372,7 +373,7 @@
         auto ip = service->ipAddress();
         auto port = service->port();
 
-        if (wanted_service_name == "") {
+        if (wanted_service_name.empty()) {
             cb(service_name.c_str(), reg_type.c_str(), ip.c_str(), port);
         } else if (service_name == wanted_service_name) {
             cb(service_name.c_str(), reg_type.c_str(), ip.c_str(), port);
@@ -396,14 +397,18 @@
 
+// Calls |cb| for each resolved pairing service named |service_name| (all of them if null/"").
 void adb_secure_foreach_pairing_service(const char* service_name,
                                         adb_secure_foreach_service_callback cb) {
+    // Map null to "" (= every service): a std::string_view must not be built from nullptr.
     ResolvedService::forEachService(*ResolvedService::sAdbSecurePairingServices,
                                     service_name ? service_name : "", cb);
 }
 
+// Calls |cb| for each resolved connect service named |service_name| (all of them if null/"").
 void adb_secure_foreach_connect_service(const char* service_name,
                                         adb_secure_foreach_service_callback cb) {
+    // Map null to "" (= every service): a std::string_view must not be built from nullptr.
     ResolvedService::forEachService(*ResolvedService::sAdbSecureConnectServices,
                                     service_name ? service_name : "", cb);
 }
 
 bool adb_secure_connect_by_service_name(const char* service_name) {
@@ -676,3 +675,64 @@
     ResolvedService::forEachService(*ResolvedService::sAdbSecurePairingServices, "", cb);
     return result;
 }
+
+// Looks up a resolved non-pairing adb service by mDNS instance name. |name| must be non-empty.
+// Returns std::nullopt if |name| cannot be parsed, names an unsupported service type, or does
+// not match any resolved service.
+std::optional<MdnsInfo> mdns_get_connect_service_info(std::string_view name) {
+    CHECK(!name.empty());
+
+    // forEachService() only runs initAdbServiceRegistries() after its |services| argument has
+    // been bound, so the registry pointers must not be dereferenced here while still null.
+    // (Presumably null simply means nothing has been resolved yet, so nothing can match.)
+    if (ResolvedService::sAdbTransportServices == nullptr ||
+        ResolvedService::sAdbSecureConnectServices == nullptr) {
+        return std::nullopt;
+    }
+
+    auto mdns_instance = mdns::mdns_parse_instance_name(name);
+    if (!mdns_instance.has_value()) {
+        // |name| is a string_view and is not guaranteed to be NUL-terminated.
+        D("Failed to parse mDNS name [%.*s]", static_cast<int>(name.size()), name.data());
+        return std::nullopt;
+    }
+
+    // Records the match reported by forEachService() (the last one, if several match).
+    std::optional<MdnsInfo> info;
+    auto cb = [&](const char* service_name, const char* reg_type, const char* ip_addr,
+                  uint16_t port) { info.emplace(service_name, reg_type, ip_addr, port); };
+
+    // The parsed name includes a service type: search only the registry for that type.
+    if (!mdns_instance->service_name.empty()) {
+        std::string reg_type = android::base::StringPrintf("%s.%s",
+                                                           mdns_instance->service_name.data(),
+                                                           mdns_instance->transport_type.data());
+        int index = adb_DNSServiceIndexByName(reg_type);
+        switch (index) {
+            case kADBTransportServiceRefIndex:
+                ResolvedService::forEachService(*ResolvedService::sAdbTransportServices,
+                                                mdns_instance->instance_name, cb);
+                break;
+            case kADBSecureConnectServiceRefIndex:
+                ResolvedService::forEachService(*ResolvedService::sAdbSecureConnectServices,
+                                                mdns_instance->instance_name, cb);
+                break;
+            default:
+                // Pairing and unrecognized registration types are not connectable.
+                D("Unknown reg_type [%s]", reg_type.data());
+                return std::nullopt;
+        }
+        return info;
+    }
+
+    // Bare instance name: try each non-pairing registry in turn; the first match wins.
+    for (const auto& service :
+         {ResolvedService::sAdbTransportServices, ResolvedService::sAdbSecureConnectServices}) {
+        ResolvedService::forEachService(*service, name, cb);
+        if (info.has_value()) {
+            return info;
+        }
+    }
+
+    return std::nullopt;
+}
diff --git a/adb/socket_spec.cpp b/adb/socket_spec.cpp
index d17036c..b7fd493 100644
--- a/adb/socket_spec.cpp
+++ b/adb/socket_spec.cpp
@@ -30,6 +30,7 @@
 
 #include "adb.h"
 #include "adb_utils.h"
+#include "adb_wifi.h"
 #include "sysdeps.h"
 
 using namespace std::string_literals;
@@ -201,7 +202,24 @@
             fd->reset(network_loopback_client(port_value, SOCK_STREAM, error));
         } else {
 #if ADB_HOST
-            fd->reset(network_connect(hostname, port_value, SOCK_STREAM, 0, error));
+            // Check if the address is an mdns service we can connect to.
+            if (auto mdns_info = mdns_get_connect_service_info(address.substr(4));
+                mdns_info != std::nullopt) {
+                fd->reset(network_connect(mdns_info->addr, mdns_info->port, SOCK_STREAM, 0, error));
+                if (fd->get() != -1) {
+                    // TODO(joshuaduong): We still show the ip address for the serial. Change it to
+                    // use the mdns instance name, so we can adjust to address changes on
+                    // reconnects.
+                    port_value = mdns_info->port;
+                    if (serial) {
+                        *serial = android::base::StringPrintf("%s.%s",
+                                                              mdns_info->service_name.c_str(),
+                                                              mdns_info->service_type.c_str());
+                    }
+                }
+            } else {
+                fd->reset(network_connect(hostname, port_value, SOCK_STREAM, 0, error));
+            }
 #else
             // Disallow arbitrary connections in adbd.
             *error = "adbd does not support arbitrary tcp connections";
diff --git a/adb/test_adb.py b/adb/test_adb.py
index 9912f11..4b99411 100755
--- a/adb/test_adb.py
+++ b/adb/test_adb.py
@@ -25,6 +25,7 @@
 import random
 import select
 import socket
+import string
 import struct
 import subprocess
 import sys
@@ -628,21 +629,49 @@
 class MdnsTest(unittest.TestCase):
     """Tests for adb mdns."""
 
+    @staticmethod
+    def _mdns_services(port):
+        text = subprocess.check_output(["adb", "-P", str(port), "mdns", "services"]).decode("utf8")
+        return [row.split("\t") for row in text.strip().splitlines()[1:]]  # skip header row
+
+    @staticmethod
+    def _devices(port):
+        text = subprocess.check_output(["adb", "-P", str(port), "devices"]).decode("utf8")
+        return [row.split("\t") for row in text.strip().splitlines()[1:]]  # skip header row
+
+    @contextlib.contextmanager
+    def _adb_mdns_connect(self, server_port, mdns_instance, serial, should_connect):
+        """Context manager that runs `adb connect <mdns_instance>` on the test server.
+
+        Asserts the outcome against `should_connect`, then disconnects `serial` on exit.
+        """
+
+        output = subprocess.check_output(["adb", "-P", str(server_port), "connect", mdns_instance])
+        try:
+            if should_connect:
+                self.assertEqual(output.strip(), "connected to {}".format(serial).encode("utf8"))
+            else:
+                self.assertTrue(output.startswith("failed to resolve host: '{}'"
+                    .format(mdns_instance).encode("utf8")))
+            yield
+        finally:
+            # Best-effort disconnection; discard the output. Use -P so the request goes to the
+            # same test server as the connect above, not to the default adb server.
+            subprocess.Popen(["adb", "-P", str(server_port), "disconnect", serial],
+                             stdout=subprocess.PIPE,
+                             stderr=subprocess.PIPE).communicate()
+
+
     @unittest.skipIf(not is_zeroconf_installed(), "zeroconf library not installed")
     def test_mdns_services_register_unregister(self):
         """Ensure that `adb mdns services` correctly adds and removes a service
         """
         from zeroconf import IPVersion, ServiceInfo
- 
-        def _mdns_services(port):
-            output = subprocess.check_output(["adb", "-P", str(port), "mdns", "services"])
-            return [x.split("\t") for x in output.decode("utf8").strip().splitlines()[1:]]
 
         with adb_server() as server_port:
             output = subprocess.check_output(["adb", "-P", str(server_port),
                                               "mdns", "services"]).strip()
             self.assertTrue(output.startswith(b"List of discovered mdns services"))
-            print(f"services={_mdns_services(server_port)}")
 
             """TODO(joshuaduong): Add ipv6 tests once we have it working in adb"""
             """Register/Unregister a service"""
@@ -656,20 +685,52 @@
                         name=serv_instance + "." + serv_type + "local.",
                         addresses=[serv_ipaddr],
                         port=serv_port)
-                print(f"Registering {serv_instance}.{serv_type} ...")
                 with zeroconf_register_service(zc, service_info) as info:
                     """Give adb some time to register the service"""
                     time.sleep(1)
-                    print(f"services={_mdns_services(server_port)}")
                     self.assertTrue(any((serv_instance in line and serv_type in line)
-                        for line in _mdns_services(server_port)))
+                        for line in MdnsTest._mdns_services(server_port)))
 
                 """Give adb some time to unregister the service"""
-                print("Unregistering mdns service...")
                 time.sleep(1)
-                print(f"services={_mdns_services(server_port)}")
                 self.assertFalse(any((serv_instance in line and serv_type in line)
-                    for line in _mdns_services(server_port)))
+                    for line in MdnsTest._mdns_services(server_port)))
+
+    @unittest.skipIf(not is_zeroconf_installed(), "zeroconf library not installed")
+    def test_mdns_connect(self):
+        """Ensure `adb connect <instance>` works for an mDNS-advertised fake adbd.
+
+        Pairing services are expected to fail to resolve instead."""
+        from zeroconf import IPVersion, ServiceInfo
+
+        with adb_server() as server_port:
+            with zeroconf_context(IPVersion.V4Only) as zc:
+                suffix = ''.join(random.choice(string.ascii_letters) for _ in range(4))
+                instance = "fakeadbd-" + suffix
+                reg_type = "_" + self.service_name + "._tcp."
+                loopback = socket.inet_aton("127.0.0.1")
+                connectable = self.service_name != "adb-tls-pairing"
+                with fake_adbd() as (adbd_port, _):
+                    info = ServiceInfo(reg_type + "local.",
+                                       name=instance + "." + reg_type + "local.",
+                                       addresses=[loopback],
+                                       port=adbd_port)
+                    with zeroconf_register_service(zc, info):
+                        # Give adb some time to register the service.
+                        time.sleep(1)
+                        self.assertTrue(any(instance in row and reg_type in row
+                                            for row in MdnsTest._mdns_services(server_port)))
+                        serial = instance + "." + reg_type
+                        with self._adb_mdns_connect(server_port, instance, serial,
+                                                    connectable):
+                            if connectable:
+                                self.assertEqual(MdnsTest._devices(server_port),
+                                                 [[serial, "device"]])
+
+                    # Give adb some time to unregister the service.
+                    time.sleep(1)
+                    self.assertFalse(any(instance in row and reg_type in row
+                                         for row in MdnsTest._mdns_services(server_port)))
 
 def main():
     """Main entrypoint."""