libpdx_uds: Fix send/receive over socket to handle signal interrupts

Previous implementation of send/receive didn't account for the fact that
send/receive operation might be interrupted by a signal and transfer
fewer bytes than requested.

Fix this by repeatedly calling send/recv until all the requested data
is transferred over sockets.

Also added a number of unit tests for send/receive functions.

Bug: 37427314
Test: `m -j32` succeeds for Sailfish.
      `libpdx_uds_tests` pass on device

Change-Id: Ib8f78967af3c218d9f18fb3dfe8953c35800540b
diff --git a/libs/vr/libpdx_uds/Android.bp b/libs/vr/libpdx_uds/Android.bp
index a73ba34..f2bcc0c 100644
--- a/libs/vr/libpdx_uds/Android.bp
+++ b/libs/vr/libpdx_uds/Android.bp
@@ -35,10 +35,12 @@
         "-Werror",
     ],
     srcs: [
+        "ipc_helper_tests.cpp",
         "remote_method_tests.cpp",
         "service_framework_tests.cpp",
     ],
     static_libs: [
+        "libgmock",
         "libpdx_uds",
         "libpdx",
     ],
diff --git a/libs/vr/libpdx_uds/ipc_helper.cpp b/libs/vr/libpdx_uds/ipc_helper.cpp
index d604f62..b675894 100644
--- a/libs/vr/libpdx_uds/ipc_helper.cpp
+++ b/libs/vr/libpdx_uds/ipc_helper.cpp
@@ -18,6 +18,150 @@
 namespace pdx {
 namespace uds {
 
+namespace {
+
+// Default implementations of Send/Receive interfaces to use standard socket
+// send/sendmsg/recv/recvmsg functions.
+class SocketSender : public SendInterface {
+ public:
+  ssize_t Send(int socket_fd, const void* data, size_t size,
+               int flags) override {
+    return send(socket_fd, data, size, flags);
+  }
+  ssize_t SendMessage(int socket_fd, const msghdr* msg, int flags) override {
+    return sendmsg(socket_fd, msg, flags);
+  }
+} g_socket_sender;
+
+class SocketReceiver : public RecvInterface {
+ public:
+  ssize_t Receive(int socket_fd, void* data, size_t size, int flags) override {
+    return recv(socket_fd, data, size, flags);
+  }
+  ssize_t ReceiveMessage(int socket_fd, msghdr* msg, int flags) override {
+    return recvmsg(socket_fd, msg, flags);
+  }
+} g_socket_receiver;
+
+}  // anonymous namespace
+
+// Helper wrappers around send()/sendmsg() which repeat send() calls on data
+// that was not sent with the initial call to send/sendmsg. This is important to
+// handle transmissions interrupted by signals.
+Status<void> SendAll(SendInterface* sender, const BorrowedHandle& socket_fd,
+                     const void* data, size_t size) {
+  Status<void> ret;
+  const uint8_t* ptr = static_cast<const uint8_t*>(data);
+  while (size > 0) {
+    ssize_t size_written =
+        RETRY_EINTR(sender->Send(socket_fd.Get(), ptr, size, MSG_NOSIGNAL));
+    if (size_written < 0) {
+      ret.SetError(errno);
+      ALOGE("SendAll: Failed to send data over socket: %s",
+            ret.GetErrorMessage().c_str());
+      break;
+    }
+    size -= size_written;
+    ptr += size_written;
+  }
+  return ret;
+}
+
+Status<void> SendMsgAll(SendInterface* sender, const BorrowedHandle& socket_fd,
+                        const msghdr* msg) {
+  Status<void> ret;
+  ssize_t sent_size =
+      RETRY_EINTR(sender->SendMessage(socket_fd.Get(), msg, MSG_NOSIGNAL));
+  if (sent_size < 0) {
+    ret.SetError(errno);
+    ALOGE("SendMsgAll: Failed to send data over socket: %s",
+          ret.GetErrorMessage().c_str());
+    return ret;
+  }
+
+  ssize_t chunk_start_offset = 0;
+  for (size_t i = 0; i < msg->msg_iovlen; i++) {
+    ssize_t chunk_end_offset = chunk_start_offset + msg->msg_iov[i].iov_len;
+    if (sent_size < chunk_end_offset) {
+      size_t offset_within_chunk = sent_size - chunk_start_offset;
+      size_t data_size = msg->msg_iov[i].iov_len - offset_within_chunk;
+      const uint8_t* chunk_base =
+          static_cast<const uint8_t*>(msg->msg_iov[i].iov_base);
+      ret = SendAll(sender, socket_fd, chunk_base + offset_within_chunk,
+                    data_size);
+      if (!ret)
+        break;
+      sent_size += data_size;
+    }
+    chunk_start_offset = chunk_end_offset;
+  }
+  return ret;
+}
+
+// Helper wrappers around recv()/recvmsg() which repeat recv() calls on data
+// that was not received with the initial call to recvmsg(). This is important
+// to handle transmissions interrupted by signals as well as the case when
+// initial data did not arrive in a single chunk over the socket (e.g. socket
+// buffer was full at the time of transmission, and only portion of initial
+// message was sent and the rest was blocked until the buffer was cleared by the
+// receiving side).
+Status<void> RecvMsgAll(RecvInterface* receiver,
+                        const BorrowedHandle& socket_fd, msghdr* msg) {
+  Status<void> ret;
+  ssize_t size_read = RETRY_EINTR(receiver->ReceiveMessage(
+      socket_fd.Get(), msg, MSG_WAITALL | MSG_CMSG_CLOEXEC));
+  if (size_read < 0) {
+    ret.SetError(errno);
+    ALOGE("RecvMsgAll: Failed to receive data from socket: %s",
+          ret.GetErrorMessage().c_str());
+    return ret;
+  } else if (size_read == 0) {
+    ret.SetError(ESHUTDOWN);
+    ALOGW("RecvMsgAll: Socket has been shut down");
+    return ret;
+  }
+
+  ssize_t chunk_start_offset = 0;
+  for (size_t i = 0; i < msg->msg_iovlen; i++) {
+    ssize_t chunk_end_offset = chunk_start_offset + msg->msg_iov[i].iov_len;
+    if (size_read < chunk_end_offset) {
+      size_t offset_within_chunk = size_read - chunk_start_offset;
+      size_t data_size = msg->msg_iov[i].iov_len - offset_within_chunk;
+      uint8_t* chunk_base = static_cast<uint8_t*>(msg->msg_iov[i].iov_base);
+      ret = RecvAll(receiver, socket_fd, chunk_base + offset_within_chunk,
+                    data_size);
+      if (!ret)
+        break;
+      size_read += data_size;
+    }
+    chunk_start_offset = chunk_end_offset;
+  }
+  return ret;
+}
+
+Status<void> RecvAll(RecvInterface* receiver, const BorrowedHandle& socket_fd,
+                     void* data, size_t size) {
+  Status<void> ret;
+  uint8_t* ptr = static_cast<uint8_t*>(data);
+  while (size > 0) {
+    ssize_t size_read = RETRY_EINTR(receiver->Receive(
+        socket_fd.Get(), ptr, size, MSG_WAITALL | MSG_CMSG_CLOEXEC));
+    if (size_read < 0) {
+      ret.SetError(errno);
+      ALOGE("RecvAll: Failed to receive data from socket: %s",
+            ret.GetErrorMessage().c_str());
+      break;
+    } else if (size_read == 0) {
+      ret.SetError(ESHUTDOWN);
+      ALOGW("RecvAll: Socket has been shut down");
+      break;
+    }
+    size -= size_read;
+    ptr += size_read;
+  }
+  return ret;
+}
+
 uint32_t kMagicPreamble = 0x7564736d;  // 'udsm'.
 
 struct MessagePreamble {
@@ -32,17 +176,14 @@
 
 Status<void> SendPayload::Send(const BorrowedHandle& socket_fd,
                                const ucred* cred) {
+  SendInterface* sender = sender_ ? sender_ : &g_socket_sender;
   MessagePreamble preamble;
   preamble.magic = kMagicPreamble;
   preamble.data_size = buffer_.size();
   preamble.fd_count = file_handles_.size();
-
-  ssize_t ret = RETRY_EINTR(
-      send(socket_fd.Get(), &preamble, sizeof(preamble), MSG_NOSIGNAL));
-  if (ret < 0)
-    return ErrorStatus(errno);
-  if (ret != sizeof(preamble))
-    return ErrorStatus(EIO);
+  Status<void> ret = SendAll(sender, socket_fd, &preamble, sizeof(preamble));
+  if (!ret)
+    return ret;
 
   msghdr msg = {};
   iovec recv_vect = {buffer_.data(), buffer_.size()};
@@ -72,12 +213,7 @@
     }
   }
 
-  ret = RETRY_EINTR(sendmsg(socket_fd.Get(), &msg, MSG_NOSIGNAL));
-  if (ret < 0)
-    return ErrorStatus(errno);
-  if (static_cast<size_t>(ret) != buffer_.size())
-    return ErrorStatus(EIO);
-  return {};
+  return SendMsgAll(sender, socket_fd, &msg);
 }
 
 // MessageWriter
@@ -132,15 +268,16 @@
 
 Status<void> ReceivePayload::Receive(const BorrowedHandle& socket_fd,
                                      ucred* cred) {
+  RecvInterface* receiver = receiver_ ? receiver_ : &g_socket_receiver;
   MessagePreamble preamble;
-  ssize_t ret = RETRY_EINTR(
-      recv(socket_fd.Get(), &preamble, sizeof(preamble), MSG_WAITALL));
-  if (ret < 0)
-    return ErrorStatus(errno);
-  else if (ret == 0)
-    return ErrorStatus(ESHUTDOWN);
-  else if (ret != sizeof(preamble) || preamble.magic != kMagicPreamble)
-    return ErrorStatus(EIO);
+  Status<void> ret = RecvAll(receiver, socket_fd, &preamble, sizeof(preamble));
+  if (!ret)
+    return ret;
+
+  if (preamble.magic != kMagicPreamble) {
+    ret.SetError(EIO);
+    return ret;
+  }
 
   buffer_.resize(preamble.data_size);
   file_handles_.clear();
@@ -159,13 +296,9 @@
     msg.msg_control = alloca(msg.msg_controllen);
   }
 
-  ret = RETRY_EINTR(recvmsg(socket_fd.Get(), &msg, MSG_WAITALL));
-  if (ret < 0)
-    return ErrorStatus(errno);
-  else if (ret == 0)
-    return ErrorStatus(ESHUTDOWN);
-  else if (static_cast<uint32_t>(ret) != preamble.data_size)
-    return ErrorStatus(EIO);
+  ret = RecvMsgAll(receiver, socket_fd, &msg);
+  if (!ret)
+    return ret;
 
   bool cred_available = false;
   file_handles_.reserve(preamble.fd_count);
@@ -186,11 +319,10 @@
     cmsg = CMSG_NXTHDR(&msg, cmsg);
   }
 
-  if (cred && !cred_available) {
-    return ErrorStatus(EIO);
-  }
+  if (cred && !cred_available)
+    ret.SetError(EIO);
 
-  return {};
+  return ret;
 }
 
 // MessageReader
@@ -223,13 +355,7 @@
 
 Status<void> SendData(const BorrowedHandle& socket_fd, const void* data,
                       size_t size) {
-  ssize_t size_written =
-      RETRY_EINTR(send(socket_fd.Get(), data, size, MSG_NOSIGNAL));
-  if (size_written < 0)
-    return ErrorStatus(errno);
-  if (static_cast<size_t>(size_written) != size)
-    return ErrorStatus(EIO);
-  return {};
+  return SendAll(&g_socket_sender, socket_fd, data, size);
 }
 
 Status<void> SendDataVector(const BorrowedHandle& socket_fd, const iovec* data,
@@ -237,26 +363,12 @@
   msghdr msg = {};
   msg.msg_iov = const_cast<iovec*>(data);
   msg.msg_iovlen = count;
-  ssize_t size_written =
-      RETRY_EINTR(sendmsg(socket_fd.Get(), &msg, MSG_NOSIGNAL));
-  if (size_written < 0)
-    return ErrorStatus(errno);
-  if (static_cast<size_t>(size_written) != CountVectorSize(data, count))
-    return ErrorStatus(EIO);
-  return {};
+  return SendMsgAll(&g_socket_sender, socket_fd, &msg);
 }
 
 Status<void> ReceiveData(const BorrowedHandle& socket_fd, void* data,
                          size_t size) {
-  ssize_t size_read =
-      RETRY_EINTR(recv(socket_fd.Get(), data, size, MSG_WAITALL));
-  if (size_read < 0)
-    return ErrorStatus(errno);
-  else if (size_read == 0)
-    return ErrorStatus(ESHUTDOWN);
-  else if (static_cast<size_t>(size_read) != size)
-    return ErrorStatus(EIO);
-  return {};
+  return RecvAll(&g_socket_receiver, socket_fd, data, size);
 }
 
 Status<void> ReceiveDataVector(const BorrowedHandle& socket_fd,
@@ -264,14 +376,7 @@
   msghdr msg = {};
   msg.msg_iov = const_cast<iovec*>(data);
   msg.msg_iovlen = count;
-  ssize_t size_read = RETRY_EINTR(recvmsg(socket_fd.Get(), &msg, MSG_WAITALL));
-  if (size_read < 0)
-    return ErrorStatus(errno);
-  else if (size_read == 0)
-    return ErrorStatus(ESHUTDOWN);
-  else if (static_cast<size_t>(size_read) != CountVectorSize(data, count))
-    return ErrorStatus(EIO);
-  return {};
+  return RecvMsgAll(&g_socket_receiver, socket_fd, &msg);
 }
 
 size_t CountVectorSize(const iovec* vector, size_t count) {
diff --git a/libs/vr/libpdx_uds/ipc_helper_tests.cpp b/libs/vr/libpdx_uds/ipc_helper_tests.cpp
new file mode 100644
index 0000000..bfa827e
--- /dev/null
+++ b/libs/vr/libpdx_uds/ipc_helper_tests.cpp
@@ -0,0 +1,365 @@
+#include "uds/ipc_helper.h"
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+
+using testing::Return;
+using testing::SetErrnoAndReturn;
+using testing::_;
+
+using android::pdx::BorrowedHandle;
+using android::pdx::uds::SendInterface;
+using android::pdx::uds::RecvInterface;
+using android::pdx::uds::SendAll;
+using android::pdx::uds::SendMsgAll;
+using android::pdx::uds::RecvAll;
+using android::pdx::uds::RecvMsgAll;
+
+namespace {
+
+// Useful constants for tests.
+static constexpr intptr_t kPtr = 1234;
+static constexpr int kSocketFd = 5678;
+static const BorrowedHandle kSocket{kSocketFd};
+
+// Helper functions to construct test data pointer values.
+void* IntToPtr(intptr_t value) { return reinterpret_cast<void*>(value); }
+const void* IntToConstPtr(intptr_t value) {
+  return reinterpret_cast<const void*>(value);
+}
+
+// Mock classes for SendInterface/RecvInterface.
+class MockSender : public SendInterface {
+ public:
+  MOCK_METHOD4(Send, ssize_t(int socket_fd, const void* data, size_t size,
+                             int flags));
+  MOCK_METHOD3(SendMessage,
+               ssize_t(int socket_fd, const msghdr* msg, int flags));
+};
+
+class MockReceiver : public RecvInterface {
+ public:
+  MOCK_METHOD4(Receive,
+               ssize_t(int socket_fd, void* data, size_t size, int flags));
+  MOCK_METHOD3(ReceiveMessage, ssize_t(int socket_fd, msghdr* msg, int flags));
+};
+
+// Test case classes.
+class SendTest : public testing::Test {
+ public:
+  SendTest() {
+    ON_CALL(sender_, Send(_, _, _, _))
+        .WillByDefault(SetErrnoAndReturn(EIO, -1));
+    ON_CALL(sender_, SendMessage(_, _, _))
+        .WillByDefault(SetErrnoAndReturn(EIO, -1));
+  }
+
+ protected:
+  MockSender sender_;
+};
+
+class RecvTest : public testing::Test {
+ public:
+  RecvTest() {
+    ON_CALL(receiver_, Receive(_, _, _, _))
+        .WillByDefault(SetErrnoAndReturn(EIO, -1));
+    ON_CALL(receiver_, ReceiveMessage(_, _, _))
+        .WillByDefault(SetErrnoAndReturn(EIO, -1));
+  }
+
+ protected:
+  MockReceiver receiver_;
+};
+
+class MessageTestBase : public testing::Test {
+ public:
+  MessageTestBase() {
+    memset(&msg_, 0, sizeof(msg_));
+    msg_.msg_iovlen = data_.size();
+    msg_.msg_iov = data_.data();
+  }
+
+ protected:
+  static constexpr intptr_t kPtr1 = kPtr;
+  static constexpr intptr_t kPtr2 = kPtr + 200;
+  static constexpr intptr_t kPtr3 = kPtr + 1000;
+
+  MockSender sender_;
+  msghdr msg_;
+  std::vector<iovec> data_{
+      {IntToPtr(kPtr1), 100}, {IntToPtr(kPtr2), 200}, {IntToPtr(kPtr3), 300}};
+};
+
+class SendMessageTest : public MessageTestBase {
+ public:
+  SendMessageTest() {
+    ON_CALL(sender_, Send(_, _, _, _))
+        .WillByDefault(SetErrnoAndReturn(EIO, -1));
+    ON_CALL(sender_, SendMessage(_, _, _))
+        .WillByDefault(SetErrnoAndReturn(EIO, -1));
+  }
+
+ protected:
+  MockSender sender_;
+};
+
+class RecvMessageTest : public MessageTestBase {
+ public:
+  RecvMessageTest() {
+    ON_CALL(receiver_, Receive(_, _, _, _))
+        .WillByDefault(SetErrnoAndReturn(EIO, -1));
+    ON_CALL(receiver_, ReceiveMessage(_, _, _))
+        .WillByDefault(SetErrnoAndReturn(EIO, -1));
+  }
+
+ protected:
+  MockReceiver receiver_;
+};
+
+// Actual tests.
+
+// SendAll
+TEST_F(SendTest, Complete) {
+  EXPECT_CALL(sender_, Send(kSocketFd, IntToConstPtr(kPtr), 100, MSG_NOSIGNAL))
+      .WillOnce(Return(100));
+
+  auto status = SendAll(&sender_, kSocket, IntToConstPtr(kPtr), 100);
+  EXPECT_TRUE(status);
+}
+
+TEST_F(SendTest, Signal) {
+  EXPECT_CALL(sender_, Send(kSocketFd, IntToConstPtr(kPtr), 100, MSG_NOSIGNAL))
+      .WillOnce(Return(20));
+  EXPECT_CALL(sender_,
+              Send(kSocketFd, IntToConstPtr(kPtr + 20), 80, MSG_NOSIGNAL))
+      .WillOnce(Return(40));
+  EXPECT_CALL(sender_,
+              Send(kSocketFd, IntToConstPtr(kPtr + 60), 40, MSG_NOSIGNAL))
+      .WillOnce(Return(40));
+
+  auto status = SendAll(&sender_, kSocket, IntToConstPtr(kPtr), 100);
+  EXPECT_TRUE(status);
+}
+
+TEST_F(SendTest, Eintr) {
+  EXPECT_CALL(sender_, Send(kSocketFd, IntToConstPtr(kPtr), 100, MSG_NOSIGNAL))
+      .WillOnce(SetErrnoAndReturn(EINTR, -1))
+      .WillOnce(Return(100));
+
+  auto status = SendAll(&sender_, kSocket, IntToConstPtr(kPtr), 100);
+  EXPECT_TRUE(status);
+}
+
+TEST_F(SendTest, Error) {
+  EXPECT_CALL(sender_, Send(kSocketFd, IntToConstPtr(kPtr), 100, MSG_NOSIGNAL))
+      .WillOnce(SetErrnoAndReturn(EIO, -1));
+
+  auto status = SendAll(&sender_, kSocket, IntToConstPtr(kPtr), 100);
+  ASSERT_FALSE(status);
+  EXPECT_EQ(EIO, status.error());
+}
+
+TEST_F(SendTest, Error2) {
+  EXPECT_CALL(sender_, Send(kSocketFd, IntToConstPtr(kPtr), 100, MSG_NOSIGNAL))
+      .WillOnce(Return(50));
+  EXPECT_CALL(sender_,
+              Send(kSocketFd, IntToConstPtr(kPtr + 50), 50, MSG_NOSIGNAL))
+      .WillOnce(SetErrnoAndReturn(EIO, -1));
+
+  auto status = SendAll(&sender_, kSocket, IntToConstPtr(kPtr), 100);
+  ASSERT_FALSE(status);
+  EXPECT_EQ(EIO, status.error());
+}
+
+// RecvAll
+TEST_F(RecvTest, Complete) {
+  EXPECT_CALL(receiver_, Receive(kSocketFd, IntToPtr(kPtr), 100,
+                                 MSG_WAITALL | MSG_CMSG_CLOEXEC))
+      .WillOnce(Return(100));
+
+  auto status = RecvAll(&receiver_, kSocket, IntToPtr(kPtr), 100);
+  EXPECT_TRUE(status);
+}
+
+TEST_F(RecvTest, Signal) {
+  EXPECT_CALL(receiver_, Receive(kSocketFd, IntToPtr(kPtr), 100, _))
+      .WillOnce(Return(20));
+  EXPECT_CALL(receiver_, Receive(kSocketFd, IntToPtr(kPtr + 20), 80, _))
+      .WillOnce(Return(40));
+  EXPECT_CALL(receiver_, Receive(kSocketFd, IntToPtr(kPtr + 60), 40, _))
+      .WillOnce(Return(40));
+
+  auto status = RecvAll(&receiver_, kSocket, IntToPtr(kPtr), 100);
+  EXPECT_TRUE(status);
+}
+
+TEST_F(RecvTest, Eintr) {
+  EXPECT_CALL(receiver_, Receive(kSocketFd, IntToPtr(kPtr), 100, _))
+      .WillOnce(SetErrnoAndReturn(EINTR, -1))
+      .WillOnce(Return(100));
+
+  auto status = RecvAll(&receiver_, kSocket, IntToPtr(kPtr), 100);
+  EXPECT_TRUE(status);
+}
+
+TEST_F(RecvTest, Error) {
+  EXPECT_CALL(receiver_, Receive(kSocketFd, IntToPtr(kPtr), 100, _))
+      .WillOnce(SetErrnoAndReturn(EIO, -1));
+
+  auto status = RecvAll(&receiver_, kSocket, IntToPtr(kPtr), 100);
+  ASSERT_FALSE(status);
+  EXPECT_EQ(EIO, status.error());
+}
+
+TEST_F(RecvTest, Error2) {
+  EXPECT_CALL(receiver_, Receive(kSocketFd, IntToPtr(kPtr), 100, _))
+      .WillOnce(Return(30));
+  EXPECT_CALL(receiver_, Receive(kSocketFd, IntToPtr(kPtr + 30), 70, _))
+      .WillOnce(SetErrnoAndReturn(EIO, -1));
+
+  auto status = RecvAll(&receiver_, kSocket, IntToPtr(kPtr), 100);
+  ASSERT_FALSE(status);
+  EXPECT_EQ(EIO, status.error());
+}
+
+// SendMsgAll
+TEST_F(SendMessageTest, Complete) {
+  EXPECT_CALL(sender_, SendMessage(kSocketFd, &msg_, MSG_NOSIGNAL))
+      .WillOnce(Return(600));
+
+  auto status = SendMsgAll(&sender_, kSocket, &msg_);
+  EXPECT_TRUE(status);
+}
+
+TEST_F(SendMessageTest, Partial) {
+  EXPECT_CALL(sender_, SendMessage(kSocketFd, &msg_, _)).WillOnce(Return(70));
+  EXPECT_CALL(sender_, Send(kSocketFd, IntToConstPtr(kPtr1 + 70), 30, _))
+      .WillOnce(Return(30));
+  EXPECT_CALL(sender_, Send(kSocketFd, IntToConstPtr(kPtr2), 200, _))
+      .WillOnce(Return(190));
+  EXPECT_CALL(sender_, Send(kSocketFd, IntToConstPtr(kPtr2 + 190), 10, _))
+      .WillOnce(Return(10));
+  EXPECT_CALL(sender_, Send(kSocketFd, IntToConstPtr(kPtr3), 300, _))
+      .WillOnce(Return(300));
+
+  auto status = SendMsgAll(&sender_, kSocket, &msg_);
+  EXPECT_TRUE(status);
+}
+
+TEST_F(SendMessageTest, Partial2) {
+  EXPECT_CALL(sender_, SendMessage(kSocketFd, &msg_, _)).WillOnce(Return(310));
+  EXPECT_CALL(sender_, Send(kSocketFd, IntToConstPtr(kPtr3 + 10), 290, _))
+      .WillOnce(Return(290));
+
+  auto status = SendMsgAll(&sender_, kSocket, &msg_);
+  EXPECT_TRUE(status);
+}
+
+TEST_F(SendMessageTest, Eintr) {
+  EXPECT_CALL(sender_, SendMessage(kSocketFd, &msg_, _))
+      .WillOnce(SetErrnoAndReturn(EINTR, -1))
+      .WillOnce(Return(70));
+  EXPECT_CALL(sender_, Send(kSocketFd, IntToConstPtr(kPtr1 + 70), 30, _))
+      .WillOnce(SetErrnoAndReturn(EINTR, -1))
+      .WillOnce(Return(30));
+  EXPECT_CALL(sender_, Send(kSocketFd, IntToConstPtr(kPtr2), 200, _))
+      .WillOnce(Return(200));
+  EXPECT_CALL(sender_, Send(kSocketFd, IntToConstPtr(kPtr3), 300, _))
+      .WillOnce(Return(300));
+
+  auto status = SendMsgAll(&sender_, kSocket, &msg_);
+  EXPECT_TRUE(status);
+}
+
+TEST_F(SendMessageTest, Error) {
+  EXPECT_CALL(sender_, SendMessage(kSocketFd, &msg_, _))
+      .WillOnce(SetErrnoAndReturn(EBADF, -1));
+
+  auto status = SendMsgAll(&sender_, kSocket, &msg_);
+  ASSERT_FALSE(status);
+  EXPECT_EQ(EBADF, status.error());
+}
+
+TEST_F(SendMessageTest, Error2) {
+  EXPECT_CALL(sender_, SendMessage(kSocketFd, &msg_, _)).WillOnce(Return(20));
+  EXPECT_CALL(sender_, Send(kSocketFd, IntToConstPtr(kPtr1 + 20), 80, _))
+      .WillOnce(SetErrnoAndReturn(EBADF, -1));
+
+  auto status = SendMsgAll(&sender_, kSocket, &msg_);
+  ASSERT_FALSE(status);
+  EXPECT_EQ(EBADF, status.error());
+}
+
+// RecvMsgAll
+TEST_F(RecvMessageTest, Complete) {
+  EXPECT_CALL(receiver_,
+              ReceiveMessage(kSocketFd, &msg_, MSG_WAITALL | MSG_CMSG_CLOEXEC))
+      .WillOnce(Return(600));
+
+  auto status = RecvMsgAll(&receiver_, kSocket, &msg_);
+  EXPECT_TRUE(status);
+}
+
+TEST_F(RecvMessageTest, Partial) {
+  EXPECT_CALL(receiver_, ReceiveMessage(kSocketFd, &msg_, _))
+      .WillOnce(Return(70));
+  EXPECT_CALL(receiver_, Receive(kSocketFd, IntToPtr(kPtr1 + 70), 30, _))
+      .WillOnce(Return(30));
+  EXPECT_CALL(receiver_, Receive(kSocketFd, IntToPtr(kPtr2), 200, _))
+      .WillOnce(Return(190));
+  EXPECT_CALL(receiver_, Receive(kSocketFd, IntToPtr(kPtr2 + 190), 10, _))
+      .WillOnce(Return(10));
+  EXPECT_CALL(receiver_, Receive(kSocketFd, IntToPtr(kPtr3), 300, _))
+      .WillOnce(Return(300));
+
+  auto status = RecvMsgAll(&receiver_, kSocket, &msg_);
+  EXPECT_TRUE(status);
+}
+
+TEST_F(RecvMessageTest, Partial2) {
+  EXPECT_CALL(receiver_, ReceiveMessage(kSocketFd, &msg_, _))
+      .WillOnce(Return(310));
+  EXPECT_CALL(receiver_, Receive(kSocketFd, IntToPtr(kPtr3 + 10), 290, _))
+      .WillOnce(Return(290));
+
+  auto status = RecvMsgAll(&receiver_, kSocket, &msg_);
+  EXPECT_TRUE(status);
+}
+
+TEST_F(RecvMessageTest, Eintr) {
+  EXPECT_CALL(receiver_, ReceiveMessage(kSocketFd, &msg_, _))
+      .WillOnce(SetErrnoAndReturn(EINTR, -1))
+      .WillOnce(Return(70));
+  EXPECT_CALL(receiver_, Receive(kSocketFd, IntToPtr(kPtr1 + 70), 30, _))
+      .WillOnce(SetErrnoAndReturn(EINTR, -1))
+      .WillOnce(Return(30));
+  EXPECT_CALL(receiver_, Receive(kSocketFd, IntToPtr(kPtr2), 200, _))
+      .WillOnce(Return(200));
+  EXPECT_CALL(receiver_, Receive(kSocketFd, IntToPtr(kPtr3), 300, _))
+      .WillOnce(Return(300));
+
+  auto status = RecvMsgAll(&receiver_, kSocket, &msg_);
+  EXPECT_TRUE(status);
+}
+
+TEST_F(RecvMessageTest, Error) {
+  EXPECT_CALL(receiver_, ReceiveMessage(kSocketFd, &msg_, _))
+      .WillOnce(SetErrnoAndReturn(EBADF, -1));
+
+  auto status = RecvMsgAll(&receiver_, kSocket, &msg_);
+  ASSERT_FALSE(status);
+  EXPECT_EQ(EBADF, status.error());
+}
+
+TEST_F(RecvMessageTest, Error2) {
+  EXPECT_CALL(receiver_, ReceiveMessage(kSocketFd, &msg_, _))
+      .WillOnce(Return(20));
+  EXPECT_CALL(receiver_, Receive(kSocketFd, IntToPtr(kPtr1 + 20), 80, _))
+      .WillOnce(SetErrnoAndReturn(EBADF, -1));
+
+  auto status = RecvMsgAll(&receiver_, kSocket, &msg_);
+  ASSERT_FALSE(status);
+  EXPECT_EQ(EBADF, status.error());
+}
+
+}  // namespace
diff --git a/libs/vr/libpdx_uds/private/uds/ipc_helper.h b/libs/vr/libpdx_uds/private/uds/ipc_helper.h
index 82950a2..5b7e5ff 100644
--- a/libs/vr/libpdx_uds/private/uds/ipc_helper.h
+++ b/libs/vr/libpdx_uds/private/uds/ipc_helper.h
@@ -14,6 +14,38 @@
 namespace pdx {
 namespace uds {
 
+// Test interfaces used for unit-testing payload sending/receiving over sockets.
+class SendInterface {
+ public:
+  virtual ssize_t Send(int socket_fd, const void* data, size_t size,
+                       int flags) = 0;
+  virtual ssize_t SendMessage(int socket_fd, const msghdr* msg, int flags) = 0;
+
+ protected:
+  virtual ~SendInterface() = default;
+};
+
+class RecvInterface {
+ public:
+  virtual ssize_t Receive(int socket_fd, void* data, size_t size,
+                          int flags) = 0;
+  virtual ssize_t ReceiveMessage(int socket_fd, msghdr* msg, int flags) = 0;
+
+ protected:
+  virtual ~RecvInterface() = default;
+};
+
+// Helper methods that allow to send/receive data through abstract interfaces.
+// Useful for mocking out the underlying socket I/O.
+Status<void> SendAll(SendInterface* sender, const BorrowedHandle& socket_fd,
+                     const void* data, size_t size);
+Status<void> SendMsgAll(SendInterface* sender, const BorrowedHandle& socket_fd,
+                        const msghdr* msg);
+Status<void> RecvAll(RecvInterface* receiver, const BorrowedHandle& socket_fd,
+                     void* data, size_t size);
+Status<void> RecvMsgAll(RecvInterface* receiver,
+                        const BorrowedHandle& socket_fd, msghdr* msg);
+
 #define RETRY_EINTR(fnc_call)                 \
   ([&]() -> decltype(fnc_call) {              \
     decltype(fnc_call) result;                \
@@ -25,6 +57,7 @@
 
 class SendPayload : public MessageWriter, public OutputResourceMapper {
  public:
+  SendPayload(SendInterface* sender = nullptr) : sender_{sender} {}
   Status<void> Send(const BorrowedHandle& socket_fd);
   Status<void> Send(const BorrowedHandle& socket_fd, const ucred* cred);
 
@@ -44,12 +77,14 @@
       const RemoteChannelHandle& handle) override;
 
  private:
+  SendInterface* sender_;
   ByteBuffer buffer_;
   std::vector<int> file_handles_;
 };
 
 class ReceivePayload : public MessageReader, public InputResourceMapper {
  public:
+  ReceivePayload(RecvInterface* receiver = nullptr) : receiver_{receiver} {}
   Status<void> Receive(const BorrowedHandle& socket_fd);
   Status<void> Receive(const BorrowedHandle& socket_fd, ucred* cred);
 
@@ -64,6 +99,7 @@
                         LocalChannelHandle* handle) override;
 
  private:
+  RecvInterface* receiver_;
   ByteBuffer buffer_;
   std::vector<LocalHandle> file_handles_;
   size_t read_pos_{0};