libcutils/fastboot: improve multi-buffer write.

Makes the libcutils multi-buffer write interface friendlier and uses it
from the fastboot Socket class.

Bug: http://b/26558551
Change-Id: Ibb3a8428fc379755602de52722c1260f9e345bc0
diff --git a/fastboot/Android.mk b/fastboot/Android.mk
index 65f4e01..11d769b 100644
--- a/fastboot/Android.mk
+++ b/fastboot/Android.mk
@@ -65,6 +65,7 @@
     libdiagnose_usb \
     libbase \
     libcutils \
+    libgtest_host \
 
 # libf2fs_dlutils_host will dlopen("libf2fs_fmt_host_dyn")
 LOCAL_CFLAGS_linux := -DUSE_F2FS
diff --git a/fastboot/socket.cpp b/fastboot/socket.cpp
index 0a3ddfa..d49f47f 100644
--- a/fastboot/socket.cpp
+++ b/fastboot/socket.cpp
@@ -89,7 +89,8 @@
 
     UdpSocket(Type type, cutils_socket_t sock);
 
-    ssize_t Send(const void* data, size_t length) override;
+    bool Send(const void* data, size_t length) override;
+    bool Send(std::vector<cutils_socket_buffer_t> buffers) override;
     ssize_t Receive(void* data, size_t length, int timeout_ms) override;
 
   private:
@@ -109,9 +110,20 @@
     }
 }
 
-ssize_t UdpSocket::Send(const void* data, size_t length) {
+bool UdpSocket::Send(const void* data, size_t length) {
     return TEMP_FAILURE_RETRY(sendto(sock_, reinterpret_cast<const char*>(data), length, 0,
-                                     reinterpret_cast<sockaddr*>(addr_.get()), addr_size_));
+                                     reinterpret_cast<sockaddr*>(addr_.get()), addr_size_)) ==
+           static_cast<ssize_t>(length);
+}
+
+bool UdpSocket::Send(std::vector<cutils_socket_buffer_t> buffers) {
+    size_t total_length = 0;
+    for (const auto& buffer : buffers) {
+        total_length += buffer.length;
+    }
+
+    return TEMP_FAILURE_RETRY(socket_send_buffers_function_(
+                   sock_, buffers.data(), buffers.size())) == static_cast<ssize_t>(total_length);
 }
 
 ssize_t UdpSocket::Receive(void* data, size_t length, int timeout_ms) {
@@ -135,7 +147,8 @@
   public:
     TcpSocket(cutils_socket_t sock) : Socket(sock) {}
 
-    ssize_t Send(const void* data, size_t length) override;
+    bool Send(const void* data, size_t length) override;
+    bool Send(std::vector<cutils_socket_buffer_t> buffers) override;
     ssize_t Receive(void* data, size_t length, int timeout_ms) override;
 
     std::unique_ptr<Socket> Accept() override;
@@ -144,23 +157,57 @@
     DISALLOW_COPY_AND_ASSIGN(TcpSocket);
 };
 
-ssize_t TcpSocket::Send(const void* data, size_t length) {
-    size_t total = 0;
+// Calls send() until all |length| bytes of |data| are written. Returns false on the first error.
+bool TcpSocket::Send(const void* data, size_t length) {
+    const char* next = reinterpret_cast<const char*>(data);
+    while (length > 0) {
+        ssize_t sent = TEMP_FAILURE_RETRY(send(sock_, next, length, 0));
 
-    while (total < length) {
-        ssize_t bytes = TEMP_FAILURE_RETRY(
-                send(sock_, reinterpret_cast<const char*>(data) + total, length - total, 0));
-
-        if (bytes == -1) {
-            if (total == 0) {
-                return -1;
-            }
-            break;
+        if (sent == -1) {
+            return false;
         }
-        total += bytes;
+        // Resume at the first unsent byte; a short write must not re-send the start of |data|.
+        next += sent;
+        length -= sent;
     }
 
-    return total;
+    return true;
+}
+
+// Calls socket_send_buffers_function_() until every byte of |buffers| is written, dropping
+// fully-sent buffers and trimming a partially-sent one between calls. Returns false on the
+// first error. |buffers| is taken by value because this modifies it as data goes out.
+bool TcpSocket::Send(std::vector<cutils_socket_buffer_t> buffers) {
+    while (!buffers.empty()) {
+        ssize_t sent = TEMP_FAILURE_RETRY(
+                socket_send_buffers_function_(sock_, buffers.data(), buffers.size()));
+
+        if (sent == -1) {
+            return false;
+        }
+
+        // Skip past every buffer that was sent completely. Zero-length buffers are consumed here
+        // too; otherwise a trailing empty buffer could keep this loop spinning forever.
+        auto iter = buffers.begin();
+        while (iter != buffers.end() && static_cast<size_t>(sent) >= iter->length) {
+            sent -= iter->length;
+            ++iter;
+        }
+
+        // Shortcut the common case: we've written everything remaining.
+        if (iter == buffers.end()) {
+            break;
+        }
+
+        // Incomplete buffer write; adjust the buffer to point to the next byte to send.
+        if (sent > 0) {
+            iter->length -= sent;
+            iter->data = reinterpret_cast<const char*>(iter->data) + sent;
+        }
+        buffers.erase(buffers.begin(), iter);
+    }
+
+    return true;
 }
 
 ssize_t TcpSocket::Receive(void* data, size_t length, int timeout_ms) {
diff --git a/fastboot/socket.h b/fastboot/socket.h
index a7481db..c0bd7c9 100644
--- a/fastboot/socket.h
+++ b/fastboot/socket.h
@@ -33,11 +33,15 @@
 #ifndef SOCKET_H_
 #define SOCKET_H_
 
+#include <functional>
 #include <memory>
 #include <string>
+#include <utility>
+#include <vector>
 
 #include <android-base/macros.h>
 #include <cutils/sockets.h>
+#include <gtest/gtest_prod.h>
 
 // Socket interface to be implemented for each platform.
 class Socket {
@@ -64,8 +68,17 @@
     virtual ~Socket();
 
     // Sends |length| bytes of |data|. For TCP sockets this will continue trying to send until all
-    // bytes are transmitted. Returns the number of bytes actually sent or -1 on error.
-    virtual ssize_t Send(const void* data, size_t length) = 0;
+    // bytes are transmitted. Returns true only if all |length| bytes were sent.
+    virtual bool Send(const void* data, size_t length) = 0;
+
+    // Sends |buffers| using multi-buffer write, which can be significantly faster than making
+    // multiple calls. For UDP sockets |buffers| are all combined into a single datagram; for
+    // TCP sockets this will continue sending until all buffers are fully transmitted. Returns true
+    // only if every byte in |buffers| was sent. Takes a copy since implementations may modify it.
+    //
+    // Note: This is non-functional for UDP server Sockets because it's not currently needed and
+    // would require an additional sendto() variation of multi-buffer write.
+    virtual bool Send(std::vector<cutils_socket_buffer_t> buffers) = 0;
 
     // Waits up to |timeout_ms| to receive up to |length| bytes of data. |timout_ms| of 0 will
     // block forever. Returns the number of bytes received or -1 on error/timeout. On timeout
@@ -94,9 +107,17 @@
 
     cutils_socket_t sock_ = INVALID_SOCKET;
 
+    // Free function wrapped here so tests can replace it to verify functionality. Implementations
+    // should call this rather than using socket_send_buffers() directly.
+    std::function<ssize_t(cutils_socket_t, cutils_socket_buffer_t*, size_t)>
+            socket_send_buffers_function_ = &socket_send_buffers;
+
   private:
     int receive_timeout_ms_ = 0;
 
+    FRIEND_TEST(SocketTest, TestTcpSendBuffers);
+    FRIEND_TEST(SocketTest, TestUdpSendBuffers);
+
     DISALLOW_COPY_AND_ASSIGN(Socket);
 };
 
diff --git a/fastboot/socket_mock.cpp b/fastboot/socket_mock.cpp
index 8fea554..bcb91ec 100644
--- a/fastboot/socket_mock.cpp
+++ b/fastboot/socket_mock.cpp
@@ -38,26 +38,35 @@
     }
 }
 
-ssize_t SocketMock::Send(const void* data, size_t length) {
+bool SocketMock::Send(const void* data, size_t length) {
     if (events_.empty()) {
         ADD_FAILURE() << "Send() was called when no message was expected";
-        return -1;
+        return false;
     }
 
     if (events_.front().type != EventType::kSend) {
         ADD_FAILURE() << "Send() was called out-of-order";
-        return -1;
+        return false;
     }
 
     std::string message(reinterpret_cast<const char*>(data), length);
     if (events_.front().message != message) {
         ADD_FAILURE() << "Send() expected " << events_.front().message << ", but got " << message;
-        return -1;
+        return false;
     }
 
-    ssize_t return_value = events_.front().return_value;
     events_.pop();
-    return return_value;
+    return true;
+}
+
+// Mock out multi-buffer send as one large send, since that's what it should look like from
+// the user's perspective.
+bool SocketMock::Send(std::vector<cutils_socket_buffer_t> buffers) {
+    std::string data;
+    for (const auto& buffer : buffers) {
+        data.append(reinterpret_cast<const char*>(buffer.data), buffer.length);
+    }
+    return Send(data.data(), data.size());
 }
 
 ssize_t SocketMock::Receive(void* data, size_t length, int /*timeout_ms*/) {
@@ -106,13 +115,13 @@
 }
 
 void SocketMock::ExpectSend(std::string message) {
-    ssize_t return_value = message.length();
-    events_.push(Event(EventType::kSend, std::move(message), return_value, nullptr));
+    events_.push(Event(EventType::kSend, std::move(message), 0, nullptr));
 }
 
-void SocketMock::ExpectSendFailure(std::string message) {
-    events_.push(Event(EventType::kSend, std::move(message), -1, nullptr));
-}
+// TODO: restore this once an expected Send() event can make Send() return false to the caller.
+//void SocketMock::ExpectSendFailure(std::string message) {
+//    events_.push(Event(EventType::kSend, std::move(message), 0, nullptr));
+//}
 
 void SocketMock::AddReceive(std::string message) {
     ssize_t return_value = message.length();
diff --git a/fastboot/socket_mock.h b/fastboot/socket_mock.h
index 3e62b33..c48aa7b 100644
--- a/fastboot/socket_mock.h
+++ b/fastboot/socket_mock.h
@@ -56,7 +56,8 @@
     SocketMock();
     ~SocketMock() override;
 
-    ssize_t Send(const void* data, size_t length) override;
+    bool Send(const void* data, size_t length) override;
+    bool Send(std::vector<cutils_socket_buffer_t> buffers) override;
     ssize_t Receive(void* data, size_t length, int timeout_ms) override;
     int Close() override;
     virtual std::unique_ptr<Socket> Accept();
@@ -64,9 +65,6 @@
     // Adds an expectation for Send().
     void ExpectSend(std::string message);
 
-    // Adds an expectation for Send() that returns -1.
-    void ExpectSendFailure(std::string message);
-
     // Adds data to provide for Receive().
     void AddReceive(std::string message);
 
diff --git a/fastboot/socket_test.cpp b/fastboot/socket_test.cpp
index 7bfe967..9365792 100644
--- a/fastboot/socket_test.cpp
+++ b/fastboot/socket_test.cpp
@@ -23,8 +23,10 @@
 #include "socket.h"
 #include "socket_mock.h"
 
-#include <gtest/gtest.h>
+#include <list>
+
 #include <gtest/gtest-spi.h>
+#include <gtest/gtest.h>
 
 enum { kTestTimeoutMs = 3000 };
 
@@ -59,7 +61,7 @@
 // Sends a string over a Socket. Returns true if the full string (without terminating char)
 // was sent.
 static bool SendString(Socket* sock, const std::string& message) {
-    return sock->Send(message.c_str(), message.length()) == static_cast<ssize_t>(message.length());
+    return sock->Send(message.c_str(), message.length());
 }
 
 // Receives a string from a Socket. Returns true if the full string (without terminating char)
@@ -123,6 +125,116 @@
     }
 }
 
+// Tests UDP multi-buffer send.
+TEST(SocketTest, TestUdpSendBuffers) {
+    std::unique_ptr<Socket> sock = Socket::NewServer(Socket::Protocol::kUdp, 0);
+    std::vector<std::string> data{"foo", "bar", "12345"};
+    std::vector<cutils_socket_buffer_t> buffers{{data[0].data(), data[0].length()},
+                                                {data[1].data(), data[1].length()},
+                                                {data[2].data(), data[2].length()}};
+    ssize_t mock_return_value = 0;
+
+    // Mock out socket_send_buffers() to verify we're sending in the correct buffers and
+    // return |mock_return_value|.
+    sock->socket_send_buffers_function_ = [&buffers, &mock_return_value](
+            cutils_socket_t /*cutils_sock*/, cutils_socket_buffer_t* sent_buffers,
+            size_t num_sent_buffers) -> ssize_t {
+        EXPECT_EQ(buffers.size(), num_sent_buffers);
+        for (size_t i = 0; i < num_sent_buffers; ++i) {
+            EXPECT_EQ(buffers[i].data, sent_buffers[i].data);
+            EXPECT_EQ(buffers[i].length, sent_buffers[i].length);
+        }
+        return mock_return_value;
+    };
+
+    mock_return_value = strlen("foobar12345");
+    EXPECT_TRUE(sock->Send(buffers));
+
+    mock_return_value -= 1;
+    EXPECT_FALSE(sock->Send(buffers));
+
+    mock_return_value = 0;
+    EXPECT_FALSE(sock->Send(buffers));
+
+    mock_return_value = -1;
+    EXPECT_FALSE(sock->Send(buffers));
+}
+
+// Tests TCP re-sending until socket_send_buffers() sends all data. This is a little complicated,
+// but the general idea is that we intercept calls to socket_send_buffers() using a lambda mock
+// function that simulates partial writes.
+TEST(SocketTest, TestTcpSendBuffers) {
+    std::unique_ptr<Socket> sock = Socket::NewServer(Socket::Protocol::kTcp, 0);
+    std::vector<std::string> data{"foo", "bar", "12345"};
+    std::vector<cutils_socket_buffer_t> buffers{{data[0].data(), data[0].length()},
+                                                {data[1].data(), data[1].length()},
+                                                {data[2].data(), data[2].length()}};
+
+    // Test breaking up the buffered send at various points.
+    std::list<std::string> test_sends[] = {
+            // Successes.
+            {"foobar12345"},
+            {"f", "oob", "ar12345"},
+            {"fo", "obar12", "345"},
+            {"foo", "bar12345"},
+            {"foob", "ar123", "45"},
+            {"f", "o", "o", "b", "a", "r", "1", "2", "3", "4", "5"},
+
+            // Failures.
+            {},
+            {"f"},
+            {"foo", "bar"},
+            {"fo", "obar12"},
+            {"foobar1234"}
+    };
+
+    for (auto& test : test_sends) {
+        ssize_t bytes_sent = 0;
+        bool expect_success = true;
+
+        // Create a mock function for custom socket_send_buffers() behavior. This function will
+        // check to make sure the input buffers start at the next unsent byte, then return the
+        // number of bytes indicated by the next entry in |test|.
+        sock->socket_send_buffers_function_ = [&bytes_sent, &data, &expect_success, &test](
+                cutils_socket_t /*cutils_sock*/, cutils_socket_buffer_t* buffers,
+                size_t num_buffers) -> ssize_t {
+            EXPECT_TRUE(num_buffers > 0);
+
+            // Failure case - pretend we errored out before sending all the buffers.
+            if (test.empty()) {
+                expect_success = false;
+                return -1;
+            }
+
+            // Count the bytes we've sent to find where the next buffer should start and how many
+            // bytes should be left in it.
+            size_t byte_count = bytes_sent, data_index = 0;
+            while (data_index < data.size()) {
+                if (byte_count >= data[data_index].length()) {
+                    byte_count -= data[data_index].length();
+                    ++data_index;
+                } else {
+                    break;
+                }
+            }
+            void* expected_next_byte = &data[data_index][byte_count];
+            size_t expected_next_size = data[data_index].length() - byte_count;
+
+            EXPECT_EQ(data.size() - data_index, num_buffers);
+            EXPECT_EQ(expected_next_byte, buffers[0].data);
+            EXPECT_EQ(expected_next_size, buffers[0].length);
+
+            std::string to_send = std::move(test.front());
+            test.pop_front();
+            bytes_sent += to_send.length();
+            return to_send.length();
+        };
+
+        EXPECT_EQ(expect_success, sock->Send(buffers));
+        EXPECT_TRUE(test.empty());
+    }
+}
+
 TEST(SocketMockTest, TestSendSuccess) {
     SocketMock mock;