adb: Make the Connection object a std::shared_ptr
This change prepares for allowing the TCP-based transports to
reconnect. Shared ownership is needed because multiple threads can
access the Connection object. Concurrent access used to be safe because
an atransport instance kept the same Connection instance throughout its
lifetime, but once the Connection instance can be replaced, a thread
that is in the middle of a Write to an atransport* could use the
Connection after it has been freed. With a std::shared_ptr, each user
holds its own reference to the Connection while using it, so replacing
the Connection no longer frees it out from under that user.
Bug: 74411879
Test: system/core/adb/test_adb.py
Change-Id: I4f092be11b2095088a9a9de2c0386086814d37ce
diff --git a/adb/transport.cpp b/adb/transport.cpp
index 0ab428e..706aee6 100644
--- a/adb/transport.cpp
+++ b/adb/transport.cpp
@@ -517,8 +517,8 @@
if (t->GetConnectionState() != kCsNoPerm) {
/* initial references are the two threads */
t->ref_count = 1;
- t->connection->SetTransportName(t->serial_name());
- t->connection->SetReadCallback([t](Connection*, std::unique_ptr<apacket> p) {
+ t->connection()->SetTransportName(t->serial_name());
+ t->connection()->SetReadCallback([t](Connection*, std::unique_ptr<apacket> p) {
if (!check_header(p.get(), t)) {
D("%s: remote read: bad header", t->serial);
return false;
@@ -531,7 +531,7 @@
fdevent_run_on_main_thread([packet, t]() { handle_packet(packet, t); });
return true;
});
- t->connection->SetErrorCallback([t](Connection*, const std::string& error) {
+ t->connection()->SetErrorCallback([t](Connection*, const std::string& error) {
D("%s: connection terminated: %s", t->serial, error.c_str());
fdevent_run_on_main_thread([t]() {
handle_offline(t);
@@ -539,7 +539,7 @@
});
});
- t->connection->Start();
+ t->connection()->Start();
#if ADB_HOST
send_connect(t);
#endif
@@ -608,7 +608,7 @@
t->ref_count--;
if (t->ref_count == 0) {
D("transport: %s unref (kicking and closing)", t->serial);
- t->connection->Stop();
+ t->connection()->Stop();
remove_transport(t);
} else {
D("transport: %s unref (count=%zu)", t->serial, t->ref_count);
@@ -758,14 +758,14 @@
}
int atransport::Write(apacket* p) {
- return this->connection->Write(std::unique_ptr<apacket>(p)) ? 0 : -1;
+ return this->connection()->Write(std::unique_ptr<apacket>(p)) ? 0 : -1;
}
void atransport::Kick() {
if (!kicked_) {
D("kicking transport %s", this->serial);
kicked_ = true;
- this->connection->Stop();
+ this->connection()->Stop();
}
}
@@ -778,6 +778,11 @@
connection_state_ = state;
}
+void atransport::SetConnection(std::unique_ptr<Connection> connection) {
+ std::lock_guard<std::mutex> lock(mutex_);
+ connection_ = std::shared_ptr<Connection>(std::move(connection));
+}
+
std::string atransport::connection_state_name() const {
ConnectionState state = GetConnectionState();
switch (state) {
@@ -1094,8 +1099,9 @@
void unregister_usb_transport(usb_handle* usb) {
std::lock_guard<std::recursive_mutex> lock(transport_lock);
transport_list.remove_if([usb](atransport* t) {
- if (auto connection = dynamic_cast<UsbConnection*>(t->connection.get())) {
- return connection->handle_ == usb && t->GetConnectionState() == kCsNoPerm;
+ auto connection = t->connection();
+ if (auto usb_connection = dynamic_cast<UsbConnection*>(connection.get())) {
+ return usb_connection->handle_ == usb && t->GetConnectionState() == kCsNoPerm;
}
return false;
});