[adbwifi] Add A_STLS command.

This command will be sent by adbd to notify the client that the
connection will be over TLS.

When the client connects, it sends the CNXN packet as usual. If the
server has TLS enabled, the server replies with an A_STLS packet
(regardless of whether auth is required). At this point, the client's
only valid response is to send an A_STLS packet back. Once both sides
have exchanged A_STLS packets, both start the TLS handshake.

If auth is required, then the client will receive a CertificateRequest
with a list of known public keys (as SHA-256 hashes) that it can use in
its certificate. Otherwise, the list will be empty and the client can
assume that either any key will work, or none will work.

If the handshake succeeds, the server sends the CNXN packet and the
usual adb protocol resumes over TLS. If the handshake fails, both sides
disconnect: there is no point in retrying, because the server's known
keys have already been communicated to the client.

Bug: 111434128

Test: WIP; will add to adb_test.py/adb_device.py.

Enable wireless debugging in Settings, then run 'adb connect
<ip>:<port>'. The connection should succeed if the key is in the
keystore. Used Wireshark to confirm that the packets are encrypted.

Change-Id: I3d60647491c6c6b92297e4f628707a6457fa9420
diff --git a/adb/transport.cpp b/adb/transport.cpp
index 9dd6ec6..8b3461a 100644
--- a/adb/transport.cpp
+++ b/adb/transport.cpp
@@ -36,6 +36,9 @@
 #include <set>
 #include <thread>
 
+#include <adb/crypto/rsa_2048_key.h>
+#include <adb/crypto/x509_generator.h>
+#include <adb/tls/tls_connection.h>
 #include <android-base/logging.h>
 #include <android-base/parsenetaddress.h>
 #include <android-base/stringprintf.h>
@@ -52,7 +55,10 @@
 #include "fdevent/fdevent.h"
 #include "sysdeps/chrono.h"
 
+using namespace adb::crypto;
+using namespace adb::tls;
 using android::base::ScopedLockAssertion;
+using TlsError = TlsConnection::TlsError;
 
 static void remove_transport(atransport* transport);
 static void transport_destroy(atransport* transport);
@@ -279,18 +285,7 @@
                    << "): started multiple times";
     }
 
-    read_thread_ = std::thread([this]() {
-        LOG(INFO) << this->transport_name_ << ": read thread spawning";
-        while (true) {
-            auto packet = std::make_unique<apacket>();
-            if (!underlying_->Read(packet.get())) {
-                PLOG(INFO) << this->transport_name_ << ": read failed";
-                break;
-            }
-            read_callback_(this, std::move(packet));
-        }
-        std::call_once(this->error_flag_, [this]() { this->error_callback_(this, "read failed"); });
-    });
+    StartReadThread();
 
     write_thread_ = std::thread([this]() {
         LOG(INFO) << this->transport_name_ << ": write thread spawning";
@@ -319,6 +314,46 @@
     started_ = true;
 }
 
+void BlockingConnectionAdapter::StartReadThread() {
+    read_thread_ = std::thread([this]() {
+        LOG(INFO) << this->transport_name_ << ": read thread spawning";
+        while (true) {
+            auto packet = std::make_unique<apacket>();
+            if (!underlying_->Read(packet.get())) {
+                PLOG(INFO) << this->transport_name_ << ": read failed";
+                break;
+            }
+
+            bool got_stls_cmd = false;
+            if (packet->msg.command == A_STLS) {
+                got_stls_cmd = true;
+            }
+
+            read_callback_(this, std::move(packet));
+
+            // If we received the STLS packet, we are about to perform the TLS
+            // handshake. This read thread must stop now and be restarted after
+            // the handshake completes; otherwise it would interfere with it.
+            if (got_stls_cmd) {
+                LOG(INFO) << this->transport_name_
+                          << ": Received STLS packet. Stopping read thread.";
+                return;
+            }
+        }
+        std::call_once(this->error_flag_, [this]() { this->error_callback_(this, "read failed"); });
+    });
+}
+
+bool BlockingConnectionAdapter::DoTlsHandshake(RSA* key, std::string* auth_key) {
+    std::lock_guard<std::mutex> lock(mutex_);
+    if (read_thread_.joinable()) {
+        read_thread_.join();
+    }
+    bool success = this->underlying_->DoTlsHandshake(key, auth_key);
+    StartReadThread();
+    return success;
+}
+
 void BlockingConnectionAdapter::Reset() {
     {
         std::lock_guard<std::mutex> lock(mutex_);
@@ -388,8 +423,36 @@
     return true;
 }
 
+FdConnection::FdConnection(unique_fd fd) : fd_(std::move(fd)) {}
+
+FdConnection::~FdConnection() {}
+
+bool FdConnection::DispatchRead(void* buf, size_t len) {
+    if (tls_ != nullptr) {
+        // The TlsConnection doesn't allow 0 byte reads
+        if (len == 0) {
+            return true;
+        }
+        return tls_->ReadFully(buf, len);
+    }
+
+    return ReadFdExactly(fd_.get(), buf, len);
+}
+
+bool FdConnection::DispatchWrite(void* buf, size_t len) {
+    if (tls_ != nullptr) {
+        // The TlsConnection doesn't allow 0 byte writes
+        if (len == 0) {
+            return true;
+        }
+        return tls_->WriteFully(std::string_view(reinterpret_cast<const char*>(buf), len));
+    }
+
+    return WriteFdExactly(fd_.get(), buf, len);
+}
+
 bool FdConnection::Read(apacket* packet) {
-    if (!ReadFdExactly(fd_.get(), &packet->msg, sizeof(amessage))) {
+    if (!DispatchRead(&packet->msg, sizeof(amessage))) {
         D("remote local: read terminated (message)");
         return false;
     }
@@ -401,7 +464,7 @@
 
     packet->payload.resize(packet->msg.data_length);
 
-    if (!ReadFdExactly(fd_.get(), &packet->payload[0], packet->payload.size())) {
+    if (!DispatchRead(&packet->payload[0], packet->payload.size())) {
         D("remote local: terminated (data)");
         return false;
     }
@@ -410,13 +473,13 @@
 }
 
 bool FdConnection::Write(apacket* packet) {
-    if (!WriteFdExactly(fd_.get(), &packet->msg, sizeof(packet->msg))) {
+    if (!DispatchWrite(&packet->msg, sizeof(packet->msg))) {
         D("remote local: write terminated");
         return false;
     }
 
     if (packet->msg.data_length) {
-        if (!WriteFdExactly(fd_.get(), &packet->payload[0], packet->msg.data_length)) {
+        if (!DispatchWrite(&packet->payload[0], packet->msg.data_length)) {
             D("remote local: write terminated");
             return false;
         }
@@ -425,6 +488,51 @@
     return true;
 }
 
+bool FdConnection::DoTlsHandshake(RSA* key, std::string* auth_key) {
+    bssl::UniquePtr<EVP_PKEY> evp_pkey(EVP_PKEY_new());
+    if (!EVP_PKEY_set1_RSA(evp_pkey.get(), key)) {
+        LOG(ERROR) << "EVP_PKEY_set1_RSA failed";
+        return false;
+    }
+    auto x509 = GenerateX509Certificate(evp_pkey.get());
+    auto x509_str = X509ToPEMString(x509.get());
+    auto evp_str = Key::ToPEMString(evp_pkey.get());
+#if ADB_HOST
+    tls_ = TlsConnection::Create(TlsConnection::Role::Client,
+#else
+    tls_ = TlsConnection::Create(TlsConnection::Role::Server,
+#endif
+                                 x509_str, evp_str, fd_);
+    CHECK(tls_);
+#if ADB_HOST
+    // TLS 1.3 gives the client no message if the server rejected the
+    // certificate. This will enable a check in the tls connection to check
+    // whether the client certificate got rejected. Note that this assumes
+    // that, on handshake success, the server speaks first.
+    tls_->EnableClientPostHandshakeCheck(true);
+    // Add callback to set the certificate when server issues the
+    // CertificateRequest.
+    tls_->SetCertificateCallback(adb_tls_set_certificate);
+    // Allow any server certificate
+    tls_->SetCertVerifyCallback([](X509_STORE_CTX*) { return 1; });
+#else
+    // Add callback to check certificate against a list of known public keys
+    tls_->SetCertVerifyCallback(
+            [auth_key](X509_STORE_CTX* ctx) { return adbd_tls_verify_cert(ctx, auth_key); });
+    // Add the list of allowed client CA issuers
+    auto ca_list = adbd_tls_client_ca_list();
+    tls_->SetClientCAList(ca_list.get());
+#endif
+
+    auto err = tls_->DoHandshake();
+    if (err == TlsError::Success) {
+        return true;
+    }
+
+    tls_.reset();
+    return false;
+}
+
 void FdConnection::Close() {
     adb_shutdown(fd_.get());
     fd_.reset();
@@ -750,6 +858,26 @@
     }
 }
 
+void kick_all_tcp_tls_transports() {
+    std::lock_guard<std::recursive_mutex> lock(transport_lock);
+    for (auto t : transport_list) {
+        if (t->IsTcpDevice() && t->use_tls) {
+            t->Kick();
+        }
+    }
+}
+
+#if !ADB_HOST
+void kick_all_transports_by_auth_key(std::string_view auth_key) {
+    std::lock_guard<std::recursive_mutex> lock(transport_lock);
+    for (auto t : transport_list) {
+        if (auth_key == t->auth_key) {
+            t->Kick();
+        }
+    }
+}
+#endif
+
 /* the fdevent select pump is single threaded */
 void register_transport(atransport* transport) {
     tmsg m;
@@ -1026,6 +1154,10 @@
     return protocol_version;
 }
 
+int atransport::get_tls_version() const {
+    return tls_version;
+}
+
 size_t atransport::get_max_payload() const {
     return max_payload;
 }
@@ -1221,8 +1353,9 @@
 #endif  // ADB_HOST
 
 bool register_socket_transport(unique_fd s, std::string serial, int port, int local,
-                               atransport::ReconnectCallback reconnect, int* error) {
+                               atransport::ReconnectCallback reconnect, bool use_tls, int* error) {
     atransport* t = new atransport(std::move(reconnect), kCsOffline);
+    t->use_tls = use_tls;
 
     D("transport: %s init'ing for socket %d, on port %d", serial.c_str(), s.get(), port);
     if (init_socket_transport(t, std::move(s), port, local) < 0) {
@@ -1360,6 +1493,15 @@
 }
 
 #if ADB_HOST
+std::shared_ptr<RSA> atransport::Key() {
+    if (keys_.empty()) {
+        return nullptr;
+    }
+
+    std::shared_ptr<RSA> result = keys_[0];
+    return result;
+}
+
 std::shared_ptr<RSA> atransport::NextKey() {
     if (keys_.empty()) {
         LOG(INFO) << "fetching keys for transport " << this->serial_name();
@@ -1367,10 +1509,11 @@
 
         // We should have gotten at least one key: the one that's automatically generated.
         CHECK(!keys_.empty());
+    } else {
+        keys_.pop_front();
     }
 
     std::shared_ptr<RSA> result = keys_[0];
-    keys_.pop_front();
     return result;
 }