fastboot: get rid of manual transport memory management

Existing code leaks Transport objects. Use smart pointers for
transports to fix those leaks and to get rid of manual memory
management.

Test: atest fastboot_test
Test: manually checked transport isn't leaking anymore
Bug: 296629925
Change-Id: Ifdf162d5084f61ae5c1d2b56a897464af58100da
Signed-off-by: Dmitrii Merkurev <dimorinny@google.com>
diff --git a/fastboot/fastboot.cpp b/fastboot/fastboot.cpp
index 71a228e..fa21ab7 100644
--- a/fastboot/fastboot.cpp
+++ b/fastboot/fastboot.cpp
@@ -350,23 +350,22 @@
 //
-// The returned Transport is a singleton, so multiple calls to this function will return the same
-// object, and the caller should not attempt to delete the returned Transport.
-static Transport* open_device(const char* local_serial, bool wait_for_device = true,
-                              bool announce = true) {
+// The caller owns the returned Transport; it is freed automatically when the returned
+// std::unique_ptr goes out of scope. An empty pointer means no device was opened.
+static std::unique_ptr<Transport> open_device(const char* local_serial,
+                                              bool wait_for_device = true,
+                                              bool announce = true) {
     const Result<NetworkSerial, FastbootError> network_serial = ParseNetworkSerial(local_serial);

-    Transport* transport = nullptr;
+    std::unique_ptr<Transport> transport;
     while (true) {
         if (network_serial.ok()) {
             std::string error;
             if (network_serial->protocol == Socket::Protocol::kTcp) {
-                transport = tcp::Connect(network_serial->address, network_serial->port, &error)
-                                    .release();
+                transport = tcp::Connect(network_serial->address, network_serial->port, &error);
             } else if (network_serial->protocol == Socket::Protocol::kUdp) {
-                transport = udp::Connect(network_serial->address, network_serial->port, &error)
-                                    .release();
+                transport = udp::Connect(network_serial->address, network_serial->port, &error);
             }

-            if (transport == nullptr && announce) {
+            if (!transport && announce) {
                 LOG(ERROR) << "error: " << error;
             }
         } else if (network_serial.error().code() ==
@@ -378,12 +377,12 @@
             Expect(network_serial);
         }

-        if (transport != nullptr) {
+        if (transport) {
             return transport;
         }

         if (!wait_for_device) {
             return nullptr;
         }

         if (announce) {
@@ -394,9 +393,9 @@
     }
 }

-static Transport* NetworkDeviceConnected(bool print = false) {
-    Transport* transport = nullptr;
-    Transport* result = nullptr;
+static std::unique_ptr<Transport> NetworkDeviceConnected(bool print = false) {
+    std::unique_ptr<Transport> transport;
+    std::unique_ptr<Transport> result;

     ConnectedDevicesStorage storage;
     std::set<std::string> devices;
@@ -409,11 +408,11 @@
         transport = open_device(device.c_str(), false, false);

         if (print) {
-            PrintDevice(device.c_str(), transport == nullptr ? "offline" : "fastboot");
+            PrintDevice(device.c_str(), transport ? "fastboot" : "offline");
         }

-        if (transport != nullptr) {
-            result = transport;
+        if (transport) {
+            result = std::move(transport);
         }
     }

@@ -431,21 +430,21 @@
 //
-// The returned Transport is a singleton, so multiple calls to this function will return the same
-// object, and the caller should not attempt to delete the returned Transport.
-static Transport* open_device() {
+// The caller owns the returned Transport; it is freed automatically when the returned
+// std::unique_ptr goes out of scope.
+static std::unique_ptr<Transport> open_device() {
     if (serial != nullptr) {
         return open_device(serial);
     }

     bool announce = true;
-    Transport* transport = nullptr;
+    std::unique_ptr<Transport> transport;
     while (true) {
         transport = usb_open(match_fastboot(nullptr));
-        if (transport != nullptr) {
+        if (transport) {
             return transport;
         }

         transport = NetworkDeviceConnected();
-        if (transport != nullptr) {
+        if (transport) {
             return transport;
         }

@@ -455,6 +454,8 @@
         }
         std::this_thread::sleep_for(std::chrono::seconds(1));
     }
+
+    return transport;
 }

 static int Connect(int argc, char* argv[]) {
@@ -466,8 +467,7 @@
     const char* local_serial = *argv;
     Expect(ParseNetworkSerial(local_serial));

-    const Transport* transport = open_device(local_serial, false);
-    if (transport == nullptr) {
+    if (!open_device(local_serial, false)) {  // Probe only; the transport is released immediately.
         return 1;
     }

@@ -531,6 +531,7 @@
     usb_open(list_devices_callback);
     NetworkDeviceConnected(/* print */ true);
 }
+
 void syntax_error(const char* fmt, ...) {
     fprintf(stderr, "fastboot: usage: ");

@@ -1547,9 +1548,7 @@

 void reboot_to_userspace_fastboot() {
     fb->RebootTo("fastboot");
-
-    auto* old_transport = fb->set_transport(nullptr);
-    delete old_transport;
+    fb->set_transport(nullptr);  // Destroys the current transport.

     // Give the current connection time to close.
     std::this_thread::sleep_for(std::chrono::seconds(1));
@@ -2377,8 +2376,8 @@
         return show_help();
     }

-    Transport* transport = open_device();
-    if (transport == nullptr) {
+    std::unique_ptr<Transport> transport = open_device();  // Moved into fastboot_driver below.
+    if (!transport) {
         return 1;
     }
     fastboot::DriverCallbacks driver_callbacks = {
@@ -2388,7 +2387,7 @@
             .text = TextMessage,
     };

-    fastboot::FastBootDriver fastboot_driver(transport, driver_callbacks, false);
+    fastboot::FastBootDriver fastboot_driver(std::move(transport), driver_callbacks, false);
     fb = &fastboot_driver;
     fp->fb = &fastboot_driver;

@@ -2633,9 +2632,6 @@
     }
     fprintf(stderr, "Finished. Total time: %.3fs\n", (now() - start));

-    auto* old_transport = fb->set_transport(nullptr);
-    delete old_transport;
-
     return 0;
 }

diff --git a/fastboot/fastboot_driver.cpp b/fastboot/fastboot_driver.cpp
index 9770ab2..e5ef66b 100644
--- a/fastboot/fastboot_driver.cpp
+++ b/fastboot/fastboot_driver.cpp
@@ -58,9 +58,10 @@
 namespace fastboot {

 /*************************** PUBLIC *******************************/
-FastBootDriver::FastBootDriver(Transport* transport, DriverCallbacks driver_callbacks,
+FastBootDriver::FastBootDriver(std::unique_ptr<Transport> transport,
+                               DriverCallbacks driver_callbacks,
                                bool no_checks)
-    : transport_(transport),
+    : transport_(std::move(transport)),
       prolog_(std::move(driver_callbacks.prolog)),
       epilog_(std::move(driver_callbacks.epilog)),
       info_(std::move(driver_callbacks.info)),
@@ -627,9 +628,8 @@
     return 0;
 }

-Transport* FastBootDriver::set_transport(Transport* transport) {
-    std::swap(transport_, transport);
-    return transport;
+void FastBootDriver::set_transport(std::unique_ptr<Transport> transport) {
+    transport_ = std::move(transport);  // Destroys the previously held transport, if any.
 }

 }  // End namespace fastboot
diff --git a/fastboot/fastboot_driver.h b/fastboot/fastboot_driver.h
index 8774ead..30298cb 100644
--- a/fastboot/fastboot_driver.h
+++ b/fastboot/fastboot_driver.h
@@ -30,6 +30,7 @@
 #include <deque>
 #include <functional>
 #include <limits>
+#include <memory>
 #include <string>
 #include <vector>

@@ -63,7 +64,7 @@
     static constexpr uint32_t MAX_DOWNLOAD_SIZE = std::numeric_limits<uint32_t>::max();
     static constexpr size_t TRANSPORT_CHUNK_SIZE = 1024;

-    FastBootDriver(Transport* transport, DriverCallbacks driver_callbacks = {},
+    FastBootDriver(std::unique_ptr<Transport> transport, DriverCallbacks driver_callbacks = {},
                    bool no_checks = false);
     ~FastBootDriver();

@@ -124,9 +125,8 @@
     std::string Error();
     RetCode WaitForDisconnect() override;

-    // Note: set_transport will return the previous transport.
-    Transport* set_transport(Transport* transport);
-    Transport* transport() const { return transport_; }
+    // Takes ownership of |transport|; the previously held transport, if any, is destroyed.
+    void set_transport(std::unique_ptr<Transport> transport);

     RetCode RawCommand(const std::string& cmd, const std::string& message,
                        std::string* response = nullptr, std::vector<std::string>* info = nullptr,
@@ -143,7 +143,7 @@

     std::string ErrnoStr(const std::string& msg);

-    Transport* transport_;
+    std::unique_ptr<Transport> transport_;

   private:
     RetCode SendBuffer(android::base::borrowed_fd fd, size_t size);
diff --git a/fastboot/fastboot_driver_test.cpp b/fastboot/fastboot_driver_test.cpp
index 6f6cf8c..d2033b0 100644
--- a/fastboot/fastboot_driver_test.cpp
+++ b/fastboot/fastboot_driver_test.cpp
@@ -16,6 +16,7 @@

 #include "fastboot_driver.h"

+#include <memory>
 #include <optional>

 #include <gtest/gtest.h>
@@ -30,13 +31,14 @@
 };

 TEST_F(DriverTest, GetVar) {
-    MockTransport transport;
-    FastBootDriver driver(&transport);
+    std::unique_ptr<MockTransport> transport_pointer = std::make_unique<MockTransport>();
+    MockTransport* transport = transport_pointer.get();  // Non-owning; |driver| owns it.
+    FastBootDriver driver(std::move(transport_pointer));

-    EXPECT_CALL(transport, Write(_, _))
+    EXPECT_CALL(*transport, Write(_, _))
             .With(AllArgs(RawData("getvar:version")))
             .WillOnce(ReturnArg<1>());
-    EXPECT_CALL(transport, Read(_, _)).WillOnce(Invoke(CopyData("OKAY0.4")));
+    EXPECT_CALL(*transport, Read(_, _)).WillOnce(Invoke(CopyData("OKAY0.4")));

     std::string output;
     ASSERT_EQ(driver.GetVar("version", &output), SUCCESS) << driver.Error();
@@ -44,14 +46,15 @@
 }

 TEST_F(DriverTest, InfoMessage) {
-    MockTransport transport;
-    FastBootDriver driver(&transport);
+    std::unique_ptr<MockTransport> transport_pointer = std::make_unique<MockTransport>();
+    MockTransport* transport = transport_pointer.get();  // Non-owning; |driver| owns it.
+    FastBootDriver driver(std::move(transport_pointer));

-    EXPECT_CALL(transport, Write(_, _))
+    EXPECT_CALL(*transport, Write(_, _))
             .With(AllArgs(RawData("oem dmesg")))
             .WillOnce(ReturnArg<1>());
-    EXPECT_CALL(transport, Read(_, _)).WillOnce(Invoke(CopyData("INFOthis is an info line")));
-    EXPECT_CALL(transport, Read(_, _)).WillOnce(Invoke(CopyData("OKAY")));
+    EXPECT_CALL(*transport, Read(_, _)).WillOnce(Invoke(CopyData("INFOthis is an info line")));
+    EXPECT_CALL(*transport, Read(_, _)).WillOnce(Invoke(CopyData("OKAY")));

     std::vector<std::string> info;
     ASSERT_EQ(driver.RawCommand("oem dmesg", "", nullptr, &info), SUCCESS) << driver.Error();
@@ -60,28 +63,29 @@
 }

 TEST_F(DriverTest, TextMessage) {
-    MockTransport transport;
     std::string text;
+    std::unique_ptr<MockTransport> transport_pointer = std::make_unique<MockTransport>();
+    MockTransport* transport = transport_pointer.get();  // Non-owning; |driver| takes ownership.

     DriverCallbacks callbacks{[](const std::string&) {}, [](int) {}, [](const std::string&) {},
                               [&text](const std::string& extra_text) { text += extra_text; }};

-    FastBootDriver driver(&transport, callbacks);
+    FastBootDriver driver(std::move(transport_pointer), callbacks);

-    EXPECT_CALL(transport, Write(_, _))
+    EXPECT_CALL(*transport, Write(_, _))
             .With(AllArgs(RawData("oem trusty runtest trusty.hwaes.bench")))
             .WillOnce(ReturnArg<1>());
-    EXPECT_CALL(transport, Read(_, _)).WillOnce(Invoke(CopyData("TEXTthis is a text line")));
-    EXPECT_CALL(transport, Read(_, _))
+    EXPECT_CALL(*transport, Read(_, _)).WillOnce(Invoke(CopyData("TEXTthis is a text line")));
+    EXPECT_CALL(*transport, Read(_, _))
             .WillOnce(Invoke(
                     CopyData("TEXT, albeit very long and split over multiple TEXT messages.")));
-    EXPECT_CALL(transport, Read(_, _))
+    EXPECT_CALL(*transport, Read(_, _))
             .WillOnce(Invoke(CopyData("TEXT Indeed we can do that now with a TEXT message whenever "
                                       "we feel like it.")));
-    EXPECT_CALL(transport, Read(_, _))
+    EXPECT_CALL(*transport, Read(_, _))
             .WillOnce(Invoke(CopyData("TEXT Isn't that truly super cool?")));

-    EXPECT_CALL(transport, Read(_, _)).WillOnce(Invoke(CopyData("OKAY")));
+    EXPECT_CALL(*transport, Read(_, _)).WillOnce(Invoke(CopyData("OKAY")));

     std::vector<std::string> info;
     ASSERT_EQ(driver.RawCommand("oem trusty runtest trusty.hwaes.bench", "", nullptr, &info),
diff --git a/fastboot/fuzzy_fastboot/fixtures.cpp b/fastboot/fuzzy_fastboot/fixtures.cpp
index 9b5e5f7..94a53ed 100644
--- a/fastboot/fuzzy_fastboot/fixtures.cpp
+++ b/fastboot/fuzzy_fastboot/fixtures.cpp
@@ -128,7 +128,7 @@
             return MatchFastboot(info, device_serial);
         };
         for (int i = 0; i < MAX_USB_TRIES && !transport; i++) {
-            std::unique_ptr<UsbTransport> usb(usb_open(matcher, USB_TIMEOUT));
+            std::unique_ptr<UsbTransport> usb = usb_open(matcher, USB_TIMEOUT);  // May be empty.
             if (usb)
                 transport = std::unique_ptr<TransportSniffer>(
                         new TransportSniffer(std::move(usb), serial_port));
@@ -143,7 +143,7 @@
     } else {
         ASSERT_EQ(device_path, cb_scratch);  // The path can not change
     }
-    fb = std::unique_ptr<FastBootDriver>(new FastBootDriver(transport.get(), {}, true));
+    fb = std::unique_ptr<FastBootDriver>(new FastBootDriver(std::move(transport), {}, true));
     // No error checking since non-A/B devices may not support the command
     fb->GetVar("current-slot", &initial_slot);
 }
@@ -200,7 +200,7 @@
     if (IsFastbootOverTcp()) {
         ConnectTcpFastbootDevice();
         device_path = cb_scratch;
-        fb = std::unique_ptr<FastBootDriver>(new FastBootDriver(transport.get(), {}, true));
+        fb = std::unique_ptr<FastBootDriver>(new FastBootDriver(std::move(transport), {}, true));
         return;
     }

@@ -212,7 +212,7 @@
         return MatchFastboot(info, device_serial);
     };
     while (!transport) {
-        std::unique_ptr<UsbTransport> usb(usb_open(matcher, USB_TIMEOUT));
+        std::unique_ptr<UsbTransport> usb = usb_open(matcher, USB_TIMEOUT);  // May be empty.
         if (usb) {
             transport = std::unique_ptr<TransportSniffer>(
                     new TransportSniffer(std::move(usb), serial_port));
@@ -220,7 +220,7 @@
         std::this_thread::sleep_for(1s);
     }
     device_path = cb_scratch;
-    fb = std::unique_ptr<FastBootDriver>(new FastBootDriver(transport.get(), {}, true));
+    fb = std::unique_ptr<FastBootDriver>(new FastBootDriver(std::move(transport), {}, true));
 }

 void FastBootTest::SetLockState(bool unlock, bool assert_change) {
diff --git a/fastboot/fuzzy_fastboot/main.cpp b/fastboot/fuzzy_fastboot/main.cpp
index e635937..43aaec9 100644
--- a/fastboot/fuzzy_fastboot/main.cpp
+++ b/fastboot/fuzzy_fastboot/main.cpp
@@ -166,16 +166,15 @@
     const auto matcher = [](usb_ifc_info* info) -> int {
         return FastBootTest::MatchFastboot(info, fastboot::FastBootTest::device_serial);
     };
-    Transport* transport = nullptr;
+    std::unique_ptr<Transport> transport;  // Freed automatically on scope exit.
     for (int i = 0; i < FastBootTest::MAX_USB_TRIES && !transport; i++) {
         transport = usb_open(matcher);
         std::this_thread::sleep_for(std::chrono::milliseconds(10));
     }
-    ASSERT_NE(transport, nullptr) << "Could not find the fastboot device after: "
-                                  << 10 * FastBootTest::MAX_USB_TRIES << "ms";
+    ASSERT_NE(transport.get(), nullptr) << "Could not find the fastboot device after: "
+                                        << 10 * FastBootTest::MAX_USB_TRIES << "ms";
     if (transport) {
         transport->Close();
-        delete transport;
     }
 }

@@ -1897,7 +1896,7 @@
         const auto matcher = [](usb_ifc_info* info) -> int {
             return fastboot::FastBootTest::MatchFastboot(info, fastboot::FastBootTest::device_serial);
         };
-        Transport* transport = nullptr;
+        std::unique_ptr<Transport> transport;
         while (!transport) {
             transport = usb_open(matcher);
             std::this_thread::sleep_for(std::chrono::milliseconds(10));
diff --git a/fastboot/usb.h b/fastboot/usb.h
index 69581ab..d85cb81 100644
--- a/fastboot/usb.h
+++ b/fastboot/usb.h
@@ -29,6 +29,7 @@
 #pragma once

 #include <functional>
+#include <memory>

 #include "transport.h"

@@ -66,4 +67,4 @@
 typedef std::function<int(usb_ifc_info*)> ifc_match_func;

-// 0 is non blocking
-UsbTransport* usb_open(ifc_match_func callback, uint32_t timeout_ms = 0);
+// Returns nullptr if no matching device could be opened; a timeout_ms of 0 is non-blocking.
+std::unique_ptr<UsbTransport> usb_open(ifc_match_func callback, uint32_t timeout_ms = 0);
diff --git a/fastboot/usb_linux.cpp b/fastboot/usb_linux.cpp
index 964488c..37bb304 100644
--- a/fastboot/usb_linux.cpp
+++ b/fastboot/usb_linux.cpp
@@ -503,9 +503,12 @@
     return 0;
 }

-UsbTransport* usb_open(ifc_match_func callback, uint32_t timeout_ms) {
+std::unique_ptr<UsbTransport> usb_open(ifc_match_func callback, uint32_t timeout_ms) {
     std::unique_ptr<usb_handle> handle = find_usb_device("/sys/bus/usb/devices", callback);
-    return handle ? new LinuxUsbTransport(std::move(handle), timeout_ms) : nullptr;
+    if (!handle) {
+        return nullptr;
+    }
+    return std::make_unique<LinuxUsbTransport>(std::move(handle), timeout_ms);
 }

 /* Wait for the system to notice the device is gone, so that a subsequent
diff --git a/fastboot/usb_osx.cpp b/fastboot/usb_osx.cpp
index 8b852f5..28300b2 100644
--- a/fastboot/usb_osx.cpp
+++ b/fastboot/usb_osx.cpp
@@ -469,16 +469,18 @@
 /*
  * Definitions of this file's public functions.
  */
-
-UsbTransport* usb_open(ifc_match_func callback, uint32_t timeout_ms) {
+std::unique_ptr<UsbTransport> usb_open(ifc_match_func callback, uint32_t timeout_ms) {
     std::unique_ptr<usb_handle> handle;

     if (init_usb(callback, &handle) < 0) {
         /* Something went wrong initializing USB. */
         return nullptr;
     }

-    return new OsxUsbTransport(std::move(handle), timeout_ms);
+    if (!handle) {
+        return nullptr;
+    }
+    return std::make_unique<OsxUsbTransport>(std::move(handle), timeout_ms);
 }

 OsxUsbTransport::~OsxUsbTransport() {
diff --git a/fastboot/usb_windows.cpp b/fastboot/usb_windows.cpp
index 67bf8a3..56a6e7d 100644
--- a/fastboot/usb_windows.cpp
+++ b/fastboot/usb_windows.cpp
@@ -381,7 +381,10 @@
     return handle;
 }

-UsbTransport* usb_open(ifc_match_func callback, uint32_t) {
+std::unique_ptr<UsbTransport> usb_open(ifc_match_func callback, uint32_t) {
     std::unique_ptr<usb_handle> handle = find_usb_device(callback);
-    return handle ? new WindowsUsbTransport(std::move(handle)) : nullptr;
+    if (!handle) {
+        return nullptr;
+    }
+    return std::make_unique<WindowsUsbTransport>(std::move(handle));
 }