libpdx_uds: Fix send/receive over socket to handle signal interrupts
The previous implementation of send/receive did not account for the fact that
a send/receive operation might be interrupted by a signal and transfer
fewer bytes than requested.
Fix this by repeatedly calling send/recv until all of the requested data
has been transferred over the socket.
Also add unit tests for the send/receive helper functions.
Bug: 37427314
Test: `m -j32` succeeds for Sailfish.
`libpdx_uds_tests` pass on device
Change-Id: Ib8f78967af3c218d9f18fb3dfe8953c35800540b
diff --git a/libs/vr/libpdx_uds/Android.bp b/libs/vr/libpdx_uds/Android.bp
index a73ba34..f2bcc0c 100644
--- a/libs/vr/libpdx_uds/Android.bp
+++ b/libs/vr/libpdx_uds/Android.bp
@@ -35,10 +35,12 @@
"-Werror",
],
srcs: [
+ "ipc_helper_tests.cpp",
"remote_method_tests.cpp",
"service_framework_tests.cpp",
],
static_libs: [
+ "libgmock",
"libpdx_uds",
"libpdx",
],
diff --git a/libs/vr/libpdx_uds/ipc_helper.cpp b/libs/vr/libpdx_uds/ipc_helper.cpp
index d604f62..b675894 100644
--- a/libs/vr/libpdx_uds/ipc_helper.cpp
+++ b/libs/vr/libpdx_uds/ipc_helper.cpp
@@ -18,6 +18,150 @@
namespace pdx {
namespace uds {
+namespace {
+
+// SendInterface backed directly by the socket send()/sendmsg() calls. This is
+// the sender used whenever no test double has been injected.
+class SocketSender final : public SendInterface {
+ public:
+  ssize_t Send(int socket_fd, const void* data, size_t size,
+               int flags) override {
+    return ::send(socket_fd, data, size, flags);
+  }
+
+  ssize_t SendMessage(int socket_fd, const msghdr* msg, int flags) override {
+    return ::sendmsg(socket_fd, msg, flags);
+  }
+};
+
+// RecvInterface backed directly by the socket recv()/recvmsg() calls. This is
+// the receiver used whenever no test double has been injected.
+class SocketReceiver final : public RecvInterface {
+ public:
+  ssize_t Receive(int socket_fd, void* data, size_t size, int flags) override {
+    return ::recv(socket_fd, data, size, flags);
+  }
+
+  ssize_t ReceiveMessage(int socket_fd, msghdr* msg, int flags) override {
+    return ::recvmsg(socket_fd, msg, flags);
+  }
+};
+
+// Shared default instances; neither class has any data members.
+SocketSender g_socket_sender;
+SocketReceiver g_socket_receiver;
+
+}  // anonymous namespace
+
+// Writes all |size| bytes of |data| to |socket_fd| through |sender|.
+//
+// A single send() may transfer fewer bytes than requested (for example when a
+// signal interrupts the transfer after some data was queued), so the remainder
+// is re-sent until the whole buffer has been written. Calls failing with EINTR
+// are retried. Returns the errno of the first failed Send() call. Returns EIO
+// if Send() reports an unexpected byte count (zero, or more than was asked
+// for). Without that check, this loop would spin forever or underflow |size|.
+Status<void> SendAll(SendInterface* sender, const BorrowedHandle& socket_fd,
+                     const void* data, size_t size) {
+  Status<void> ret;
+  const uint8_t* ptr = static_cast<const uint8_t*>(data);
+  while (size > 0) {
+    ssize_t size_written =
+        RETRY_EINTR(sender->Send(socket_fd.Get(), ptr, size, MSG_NOSIGNAL));
+    if (size_written < 0) {
+      ret.SetError(errno);
+      ALOGE("SendAll: Failed to send data over socket: %s",
+            ret.GetErrorMessage().c_str());
+      break;
+    }
+    if (size_written == 0 || static_cast<size_t>(size_written) > size) {
+      // send() is not expected to report 0 bytes for a non-empty buffer, nor
+      // more bytes than requested. Bail out instead of looping with no
+      // progress or corrupting the remaining-size bookkeeping.
+      ret.SetError(EIO);
+      ALOGE("SendAll: Unexpected send result %zd for %zu remaining bytes",
+            size_written, size);
+      break;
+    }
+    size -= size_written;
+    ptr += size_written;
+  }
+  return ret;
+}
+
+// Sends the complete message described by |msg| through |sender|.
+//
+// Only the first sendmsg() call sees |msg| as a whole, including any
+// ancillary data it carries. If that call transfers just part of the iovec
+// payload, the untransmitted tail of each buffer is written with SendAll().
+// EINTR is retried; the first other error is returned.
+Status<void> SendMsgAll(SendInterface* sender, const BorrowedHandle& socket_fd,
+                        const msghdr* msg) {
+  const ssize_t initial =
+      RETRY_EINTR(sender->SendMessage(socket_fd.Get(), msg, MSG_NOSIGNAL));
+  if (initial < 0) {
+    const int error = errno;
+    Status<void> failure;
+    failure.SetError(error);
+    ALOGE("SendMsgAll: Failed to send data over socket: %s",
+          failure.GetErrorMessage().c_str());
+    return failure;
+  }
+
+  // |transferred| counts payload bytes handed to the socket so far. It is
+  // compared against each buffer's [begin, end) offset in the whole payload.
+  ssize_t transferred = initial;
+  ssize_t chunk_begin = 0;
+  for (size_t index = 0; index < msg->msg_iovlen; ++index) {
+    const iovec& chunk = msg->msg_iov[index];
+    const ssize_t chunk_end = chunk_begin + chunk.iov_len;
+    if (transferred < chunk_end) {
+      const size_t skip = transferred - chunk_begin;
+      const size_t remaining = chunk.iov_len - skip;
+      Status<void> status = SendAll(
+          sender, socket_fd, static_cast<const uint8_t*>(chunk.iov_base) + skip,
+          remaining);
+      if (!status)
+        return status;
+      transferred += remaining;
+    }
+    chunk_begin = chunk_end;
+  }
+  return {};
+}
+
+// Receives a complete message into the buffers described by |msg|.
+//
+// One recvmsg() call receives into |msg| as a whole (including any control
+// buffer it describes). It can still return before every iovec buffer is
+// full, for instance when a signal interrupts it, or when the sender's data
+// reached the socket in pieces because the socket buffer was full. The unfilled
+// tail of each buffer is then read with RecvAll(). EINTR is retried. A
+// zero-byte initial read is reported as ESHUTDOWN.
+Status<void> RecvMsgAll(RecvInterface* receiver,
+                        const BorrowedHandle& socket_fd, msghdr* msg) {
+  const ssize_t received = RETRY_EINTR(receiver->ReceiveMessage(
+      socket_fd.Get(), msg, MSG_WAITALL | MSG_CMSG_CLOEXEC));
+  if (received < 0) {
+    const int error = errno;
+    Status<void> failure;
+    failure.SetError(error);
+    ALOGE("RecvMsgAll: Failed to receive data from socket: %s",
+          failure.GetErrorMessage().c_str());
+    return failure;
+  }
+  if (received == 0) {
+    Status<void> failure;
+    failure.SetError(ESHUTDOWN);
+    ALOGW("RecvMsgAll: Socket has been shut down");
+    return failure;
+  }
+
+  // |filled| counts payload bytes received so far. It is compared against
+  // each buffer's [begin, end) offset in the whole payload.
+  ssize_t filled = received;
+  ssize_t chunk_begin = 0;
+  for (size_t index = 0; index < msg->msg_iovlen; ++index) {
+    const iovec& chunk = msg->msg_iov[index];
+    const ssize_t chunk_end = chunk_begin + chunk.iov_len;
+    if (filled < chunk_end) {
+      const size_t skip = filled - chunk_begin;
+      const size_t remaining = chunk.iov_len - skip;
+      Status<void> status =
+          RecvAll(receiver, socket_fd,
+                  static_cast<uint8_t*>(chunk.iov_base) + skip, remaining);
+      if (!status)
+        return status;
+      filled += remaining;
+    }
+    chunk_begin = chunk_end;
+  }
+  return {};
+}
+
+// Reads exactly |size| bytes from |socket_fd| into |data| through |receiver|.
+// After a short read, the rest is requested again starting right after the
+// last byte received. EINTR is retried. A zero-byte read is reported as
+// ESHUTDOWN; any other failure is reported with its errno.
+Status<void> RecvAll(RecvInterface* receiver, const BorrowedHandle& socket_fd,
+                     void* data, size_t size) {
+  Status<void> ret;
+  uint8_t* cursor = static_cast<uint8_t*>(data);
+  size_t remaining = size;
+  while (remaining != 0) {
+    const ssize_t got = RETRY_EINTR(receiver->Receive(
+        socket_fd.Get(), cursor, remaining, MSG_WAITALL | MSG_CMSG_CLOEXEC));
+    if (got == 0) {
+      ret.SetError(ESHUTDOWN);
+      ALOGW("RecvAll: Socket has been shut down");
+      return ret;
+    }
+    if (got < 0) {
+      ret.SetError(errno);
+      ALOGE("RecvAll: Failed to receive data from socket: %s",
+            ret.GetErrorMessage().c_str());
+      return ret;
+    }
+    remaining -= got;
+    cursor += got;
+  }
+  return ret;
+}
+
uint32_t kMagicPreamble = 0x7564736d; // 'udsm'.
struct MessagePreamble {
@@ -32,17 +176,14 @@
Status<void> SendPayload::Send(const BorrowedHandle& socket_fd,
const ucred* cred) {
+ SendInterface* sender = sender_ ? sender_ : &g_socket_sender;
MessagePreamble preamble;
preamble.magic = kMagicPreamble;
preamble.data_size = buffer_.size();
preamble.fd_count = file_handles_.size();
-
- ssize_t ret = RETRY_EINTR(
- send(socket_fd.Get(), &preamble, sizeof(preamble), MSG_NOSIGNAL));
- if (ret < 0)
- return ErrorStatus(errno);
- if (ret != sizeof(preamble))
- return ErrorStatus(EIO);
+ Status<void> ret = SendAll(sender, socket_fd, &preamble, sizeof(preamble));
+ if (!ret)
+ return ret;
msghdr msg = {};
iovec recv_vect = {buffer_.data(), buffer_.size()};
@@ -72,12 +213,7 @@
}
}
- ret = RETRY_EINTR(sendmsg(socket_fd.Get(), &msg, MSG_NOSIGNAL));
- if (ret < 0)
- return ErrorStatus(errno);
- if (static_cast<size_t>(ret) != buffer_.size())
- return ErrorStatus(EIO);
- return {};
+ return SendMsgAll(sender, socket_fd, &msg);
}
// MessageWriter
@@ -132,15 +268,16 @@
Status<void> ReceivePayload::Receive(const BorrowedHandle& socket_fd,
ucred* cred) {
+ RecvInterface* receiver = receiver_ ? receiver_ : &g_socket_receiver;
MessagePreamble preamble;
- ssize_t ret = RETRY_EINTR(
- recv(socket_fd.Get(), &preamble, sizeof(preamble), MSG_WAITALL));
- if (ret < 0)
- return ErrorStatus(errno);
- else if (ret == 0)
- return ErrorStatus(ESHUTDOWN);
- else if (ret != sizeof(preamble) || preamble.magic != kMagicPreamble)
- return ErrorStatus(EIO);
+ Status<void> ret = RecvAll(receiver, socket_fd, &preamble, sizeof(preamble));
+ if (!ret)
+ return ret;
+
+ if (preamble.magic != kMagicPreamble) {
+ ret.SetError(EIO);
+ return ret;
+ }
buffer_.resize(preamble.data_size);
file_handles_.clear();
@@ -159,13 +296,9 @@
msg.msg_control = alloca(msg.msg_controllen);
}
- ret = RETRY_EINTR(recvmsg(socket_fd.Get(), &msg, MSG_WAITALL));
- if (ret < 0)
- return ErrorStatus(errno);
- else if (ret == 0)
- return ErrorStatus(ESHUTDOWN);
- else if (static_cast<uint32_t>(ret) != preamble.data_size)
- return ErrorStatus(EIO);
+ ret = RecvMsgAll(receiver, socket_fd, &msg);
+ if (!ret)
+ return ret;
bool cred_available = false;
file_handles_.reserve(preamble.fd_count);
@@ -186,11 +319,10 @@
cmsg = CMSG_NXTHDR(&msg, cmsg);
}
- if (cred && !cred_available) {
- return ErrorStatus(EIO);
- }
+ if (cred && !cred_available)
+ ret.SetError(EIO);
- return {};
+ return ret;
}
// MessageReader
@@ -223,13 +355,7 @@
Status<void> SendData(const BorrowedHandle& socket_fd, const void* data,
size_t size) {
- ssize_t size_written =
- RETRY_EINTR(send(socket_fd.Get(), data, size, MSG_NOSIGNAL));
- if (size_written < 0)
- return ErrorStatus(errno);
- if (static_cast<size_t>(size_written) != size)
- return ErrorStatus(EIO);
- return {};
+ // SendAll() retries EINTR and re-sends the remainder after a short write,
+ // using the default socket-backed sender.
+ return SendAll(&g_socket_sender, socket_fd, data, size);
}
Status<void> SendDataVector(const BorrowedHandle& socket_fd, const iovec* data,
@@ -237,26 +363,12 @@
msghdr msg = {};
msg.msg_iov = const_cast<iovec*>(data);
msg.msg_iovlen = count;
- ssize_t size_written =
- RETRY_EINTR(sendmsg(socket_fd.Get(), &msg, MSG_NOSIGNAL));
- if (size_written < 0)
- return ErrorStatus(errno);
- if (static_cast<size_t>(size_written) != CountVectorSize(data, count))
- return ErrorStatus(EIO);
- return {};
+ return SendMsgAll(&g_socket_sender, socket_fd, &msg);
}
Status<void> ReceiveData(const BorrowedHandle& socket_fd, void* data,
size_t size) {
- ssize_t size_read =
- RETRY_EINTR(recv(socket_fd.Get(), data, size, MSG_WAITALL));
- if (size_read < 0)
- return ErrorStatus(errno);
- else if (size_read == 0)
- return ErrorStatus(ESHUTDOWN);
- else if (static_cast<size_t>(size_read) != size)
- return ErrorStatus(EIO);
- return {};
+ // RecvAll() retries EINTR, keeps reading after short reads, and reports a
+ // zero-byte read as ESHUTDOWN; it uses the default socket-backed receiver.
+ return RecvAll(&g_socket_receiver, socket_fd, data, size);
}
Status<void> ReceiveDataVector(const BorrowedHandle& socket_fd,
@@ -264,14 +376,7 @@
msghdr msg = {};
msg.msg_iov = const_cast<iovec*>(data);
msg.msg_iovlen = count;
- ssize_t size_read = RETRY_EINTR(recvmsg(socket_fd.Get(), &msg, MSG_WAITALL));
- if (size_read < 0)
- return ErrorStatus(errno);
- else if (size_read == 0)
- return ErrorStatus(ESHUTDOWN);
- else if (static_cast<size_t>(size_read) != CountVectorSize(data, count))
- return ErrorStatus(EIO);
- return {};
+ return RecvMsgAll(&g_socket_receiver, socket_fd, &msg);
}
size_t CountVectorSize(const iovec* vector, size_t count) {
diff --git a/libs/vr/libpdx_uds/ipc_helper_tests.cpp b/libs/vr/libpdx_uds/ipc_helper_tests.cpp
new file mode 100644
index 0000000..bfa827e
--- /dev/null
+++ b/libs/vr/libpdx_uds/ipc_helper_tests.cpp
@@ -0,0 +1,365 @@
+#include "uds/ipc_helper.h"
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+
+using testing::Return;
+using testing::SetErrnoAndReturn;
+using testing::_;
+
+using android::pdx::BorrowedHandle;
+using android::pdx::uds::SendInterface;
+using android::pdx::uds::RecvInterface;
+using android::pdx::uds::SendAll;
+using android::pdx::uds::SendMsgAll;
+using android::pdx::uds::RecvAll;
+using android::pdx::uds::RecvMsgAll;
+
+namespace {
+
+// Sentinel values shared by all tests. The helpers under test only do pointer
+// arithmetic on buffer addresses and pass the descriptor to the mocks, so any
+// distinct values work.
+constexpr intptr_t kPtr = 1234;
+constexpr int kSocketFd = 5678;
+const BorrowedHandle kSocket{kSocketFd};
+
+// Turn integers into fake buffer addresses so expectations can match on
+// offsets such as kPtr + 20.
+void* IntToPtr(intptr_t value) {
+  return reinterpret_cast<void*>(value);
+}
+const void* IntToConstPtr(intptr_t value) {
+  return static_cast<const void*>(IntToPtr(value));
+}
+
+// gMock double for SendInterface; tests script the result of every
+// send()/sendmsg() call that the helpers make.
+class MockSender : public SendInterface {
+ public:
+  MOCK_METHOD3(SendMessage,
+               ssize_t(int socket_fd, const msghdr* msg, int flags));
+  MOCK_METHOD4(Send, ssize_t(int socket_fd, const void* data, size_t size,
+                             int flags));
+};
+
+// gMock double for RecvInterface; tests script the result of every
+// recv()/recvmsg() call that the helpers make.
+class MockReceiver : public RecvInterface {
+ public:
+  MOCK_METHOD3(ReceiveMessage,
+               ssize_t(int socket_fd, msghdr* msg, int flags));
+  MOCK_METHOD4(Receive, ssize_t(int socket_fd, void* data, size_t size,
+                                int flags));
+};
+
+// Test case classes.
+// Fixture for SendAll(). By default, Send() and SendMessage() fail with
+// errno = EIO; individual tests override that with WillOnce() actions.
+class SendTest : public testing::Test {
+ protected:
+  void SetUp() override {
+    ON_CALL(sender_, SendMessage(_, _, _))
+        .WillByDefault(SetErrnoAndReturn(EIO, -1));
+    ON_CALL(sender_, Send(_, _, _, _))
+        .WillByDefault(SetErrnoAndReturn(EIO, -1));
+  }
+
+  MockSender sender_;
+};
+
+// Fixture for RecvAll(). By default, Receive() and ReceiveMessage() fail with
+// errno = EIO; individual tests override that with WillOnce() actions.
+class RecvTest : public testing::Test {
+ protected:
+  void SetUp() override {
+    ON_CALL(receiver_, ReceiveMessage(_, _, _))
+        .WillByDefault(SetErrnoAndReturn(EIO, -1));
+    ON_CALL(receiver_, Receive(_, _, _, _))
+        .WillByDefault(SetErrnoAndReturn(EIO, -1));
+  }
+
+  MockReceiver receiver_;
+};
+
+// Shared state for the SendMsgAll()/RecvMsgAll() fixtures: a msghdr that
+// describes three fake buffers of 100, 200 and 300 bytes (600 bytes total).
+// The mock sender/receiver is declared by each derived fixture; the base no
+// longer carries an unused (and shadowed) MockSender of its own.
+class MessageTestBase : public testing::Test {
+ public:
+  MessageTestBase() {
+    msg_.msg_iov = data_.data();
+    msg_.msg_iovlen = data_.size();
+  }
+
+ protected:
+  static constexpr intptr_t kPtr1 = kPtr;
+  static constexpr intptr_t kPtr2 = kPtr + 200;
+  static constexpr intptr_t kPtr3 = kPtr + 1000;
+
+  // Value-initialized, so every field not set in the constructor (name,
+  // control buffer, flags) is zero without needing memset().
+  msghdr msg_{};
+  std::vector<iovec> data_{
+      {IntToPtr(kPtr1), 100}, {IntToPtr(kPtr2), 200}, {IntToPtr(kPtr3), 300}};
+};
+
+// Fixture for SendMsgAll(). By default, Send() and SendMessage() fail with
+// errno = EIO; individual tests override that with WillOnce() actions.
+class SendMessageTest : public MessageTestBase {
+ protected:
+  void SetUp() override {
+    ON_CALL(sender_, SendMessage(_, _, _))
+        .WillByDefault(SetErrnoAndReturn(EIO, -1));
+    ON_CALL(sender_, Send(_, _, _, _))
+        .WillByDefault(SetErrnoAndReturn(EIO, -1));
+  }
+
+  MockSender sender_;
+};
+
+// Fixture for RecvMsgAll(). By default, Receive() and ReceiveMessage() fail
+// with errno = EIO; individual tests override that with WillOnce() actions.
+class RecvMessageTest : public MessageTestBase {
+ protected:
+  void SetUp() override {
+    ON_CALL(receiver_, ReceiveMessage(_, _, _))
+        .WillByDefault(SetErrnoAndReturn(EIO, -1));
+    ON_CALL(receiver_, Receive(_, _, _, _))
+        .WillByDefault(SetErrnoAndReturn(EIO, -1));
+  }
+
+  MockReceiver receiver_;
+};
+
+// Actual tests.
+
+// SendAll
+
+// The whole buffer is accepted by the first Send() call.
+TEST_F(SendTest, Complete) {
+  EXPECT_CALL(sender_, Send(kSocketFd, IntToConstPtr(kPtr), 100, MSG_NOSIGNAL))
+      .WillOnce(Return(100));
+
+  auto status = SendAll(&sender_, kSocket, IntToConstPtr(kPtr), 100);
+  EXPECT_TRUE(status);
+}
+
+// Short writes, as happen after a signal interrupts a transfer. The rest of
+// the buffer must be re-sent, in order, starting right after the last
+// accepted byte.
+TEST_F(SendTest, Signal) {
+  testing::InSequence sequence;
+  EXPECT_CALL(sender_, Send(kSocketFd, IntToConstPtr(kPtr), 100, MSG_NOSIGNAL))
+      .WillOnce(Return(20));
+  EXPECT_CALL(sender_,
+              Send(kSocketFd, IntToConstPtr(kPtr + 20), 80, MSG_NOSIGNAL))
+      .WillOnce(Return(40));
+  EXPECT_CALL(sender_,
+              Send(kSocketFd, IntToConstPtr(kPtr + 60), 40, MSG_NOSIGNAL))
+      .WillOnce(Return(40));
+
+  auto status = SendAll(&sender_, kSocket, IntToConstPtr(kPtr), 100);
+  EXPECT_TRUE(status);
+}
+
+// A call failing with EINTR is retried with identical arguments.
+TEST_F(SendTest, Eintr) {
+  EXPECT_CALL(sender_, Send(kSocketFd, IntToConstPtr(kPtr), 100, MSG_NOSIGNAL))
+      .WillOnce(SetErrnoAndReturn(EINTR, -1))
+      .WillOnce(Return(100));
+
+  auto status = SendAll(&sender_, kSocket, IntToConstPtr(kPtr), 100);
+  EXPECT_TRUE(status);
+}
+
+// Any other failure is reported with its errno.
+TEST_F(SendTest, Error) {
+  EXPECT_CALL(sender_, Send(kSocketFd, IntToConstPtr(kPtr), 100, MSG_NOSIGNAL))
+      .WillOnce(SetErrnoAndReturn(EIO, -1));
+
+  auto status = SendAll(&sender_, kSocket, IntToConstPtr(kPtr), 100);
+  ASSERT_FALSE(status);
+  EXPECT_EQ(EIO, status.error());
+}
+
+// A failure after a partial write is still reported.
+TEST_F(SendTest, Error2) {
+  testing::InSequence sequence;
+  EXPECT_CALL(sender_, Send(kSocketFd, IntToConstPtr(kPtr), 100, MSG_NOSIGNAL))
+      .WillOnce(Return(50));
+  EXPECT_CALL(sender_,
+              Send(kSocketFd, IntToConstPtr(kPtr + 50), 50, MSG_NOSIGNAL))
+      .WillOnce(SetErrnoAndReturn(EIO, -1));
+
+  auto status = SendAll(&sender_, kSocket, IntToConstPtr(kPtr), 100);
+  ASSERT_FALSE(status);
+  EXPECT_EQ(EIO, status.error());
+}
+
+// RecvAll
+
+// The whole buffer is filled by the first Receive() call.
+TEST_F(RecvTest, Complete) {
+  EXPECT_CALL(receiver_, Receive(kSocketFd, IntToPtr(kPtr), 100,
+                                 MSG_WAITALL | MSG_CMSG_CLOEXEC))
+      .WillOnce(Return(100));
+
+  auto status = RecvAll(&receiver_, kSocket, IntToPtr(kPtr), 100);
+  EXPECT_TRUE(status);
+}
+
+// Short reads: the rest of the buffer is requested, in order, starting right
+// after the last byte received.
+TEST_F(RecvTest, Signal) {
+  testing::InSequence sequence;
+  EXPECT_CALL(receiver_, Receive(kSocketFd, IntToPtr(kPtr), 100, _))
+      .WillOnce(Return(20));
+  EXPECT_CALL(receiver_, Receive(kSocketFd, IntToPtr(kPtr + 20), 80, _))
+      .WillOnce(Return(40));
+  EXPECT_CALL(receiver_, Receive(kSocketFd, IntToPtr(kPtr + 60), 40, _))
+      .WillOnce(Return(40));
+
+  auto status = RecvAll(&receiver_, kSocket, IntToPtr(kPtr), 100);
+  EXPECT_TRUE(status);
+}
+
+// A call failing with EINTR is retried with identical arguments.
+TEST_F(RecvTest, Eintr) {
+  EXPECT_CALL(receiver_, Receive(kSocketFd, IntToPtr(kPtr), 100, _))
+      .WillOnce(SetErrnoAndReturn(EINTR, -1))
+      .WillOnce(Return(100));
+
+  auto status = RecvAll(&receiver_, kSocket, IntToPtr(kPtr), 100);
+  EXPECT_TRUE(status);
+}
+
+// Any other failure is reported with its errno.
+TEST_F(RecvTest, Error) {
+  EXPECT_CALL(receiver_, Receive(kSocketFd, IntToPtr(kPtr), 100, _))
+      .WillOnce(SetErrnoAndReturn(EIO, -1));
+
+  auto status = RecvAll(&receiver_, kSocket, IntToPtr(kPtr), 100);
+  ASSERT_FALSE(status);
+  EXPECT_EQ(EIO, status.error());
+}
+
+// A failure after a partial read is still reported.
+TEST_F(RecvTest, Error2) {
+  testing::InSequence sequence;
+  EXPECT_CALL(receiver_, Receive(kSocketFd, IntToPtr(kPtr), 100, _))
+      .WillOnce(Return(30));
+  EXPECT_CALL(receiver_, Receive(kSocketFd, IntToPtr(kPtr + 30), 70, _))
+      .WillOnce(SetErrnoAndReturn(EIO, -1));
+
+  auto status = RecvAll(&receiver_, kSocket, IntToPtr(kPtr), 100);
+  ASSERT_FALSE(status);
+  EXPECT_EQ(EIO, status.error());
+}
+
+// A zero-byte read part-way through is reported as ESHUTDOWN, not retried.
+TEST_F(RecvTest, Shutdown) {
+  testing::InSequence sequence;
+  EXPECT_CALL(receiver_, Receive(kSocketFd, IntToPtr(kPtr), 100, _))
+      .WillOnce(Return(30));
+  EXPECT_CALL(receiver_, Receive(kSocketFd, IntToPtr(kPtr + 30), 70, _))
+      .WillOnce(Return(0));
+
+  auto status = RecvAll(&receiver_, kSocket, IntToPtr(kPtr), 100);
+  ASSERT_FALSE(status);
+  EXPECT_EQ(ESHUTDOWN, status.error());
+}
+
+// SendMsgAll
+
+// sendmsg() transmits all 600 payload bytes at once.
+TEST_F(SendMessageTest, Complete) {
+  EXPECT_CALL(sender_, SendMessage(kSocketFd, &msg_, MSG_NOSIGNAL))
+      .WillOnce(Return(600));
+
+  auto status = SendMsgAll(&sender_, kSocket, &msg_);
+  EXPECT_TRUE(status);
+}
+
+// sendmsg() stops inside the first buffer. The rest of that buffer and every
+// following buffer must then be sent with Send(), strictly in order, including
+// a short write inside the second buffer.
+TEST_F(SendMessageTest, Partial) {
+  testing::InSequence sequence;
+  EXPECT_CALL(sender_, SendMessage(kSocketFd, &msg_, _)).WillOnce(Return(70));
+  EXPECT_CALL(sender_, Send(kSocketFd, IntToConstPtr(kPtr1 + 70), 30, _))
+      .WillOnce(Return(30));
+  EXPECT_CALL(sender_, Send(kSocketFd, IntToConstPtr(kPtr2), 200, _))
+      .WillOnce(Return(190));
+  EXPECT_CALL(sender_, Send(kSocketFd, IntToConstPtr(kPtr2 + 190), 10, _))
+      .WillOnce(Return(10));
+  EXPECT_CALL(sender_, Send(kSocketFd, IntToConstPtr(kPtr3), 300, _))
+      .WillOnce(Return(300));
+
+  auto status = SendMsgAll(&sender_, kSocket, &msg_);
+  EXPECT_TRUE(status);
+}
+
+// sendmsg() stops inside the last buffer; only that buffer's tail is re-sent.
+TEST_F(SendMessageTest, Partial2) {
+  testing::InSequence sequence;
+  EXPECT_CALL(sender_, SendMessage(kSocketFd, &msg_, _)).WillOnce(Return(310));
+  EXPECT_CALL(sender_, Send(kSocketFd, IntToConstPtr(kPtr3 + 10), 290, _))
+      .WillOnce(Return(290));
+
+  auto status = SendMsgAll(&sender_, kSocket, &msg_);
+  EXPECT_TRUE(status);
+}
+
+// EINTR is retried for both the initial sendmsg() and the follow-up Send().
+TEST_F(SendMessageTest, Eintr) {
+  testing::InSequence sequence;
+  EXPECT_CALL(sender_, SendMessage(kSocketFd, &msg_, _))
+      .WillOnce(SetErrnoAndReturn(EINTR, -1))
+      .WillOnce(Return(70));
+  EXPECT_CALL(sender_, Send(kSocketFd, IntToConstPtr(kPtr1 + 70), 30, _))
+      .WillOnce(SetErrnoAndReturn(EINTR, -1))
+      .WillOnce(Return(30));
+  EXPECT_CALL(sender_, Send(kSocketFd, IntToConstPtr(kPtr2), 200, _))
+      .WillOnce(Return(200));
+  EXPECT_CALL(sender_, Send(kSocketFd, IntToConstPtr(kPtr3), 300, _))
+      .WillOnce(Return(300));
+
+  auto status = SendMsgAll(&sender_, kSocket, &msg_);
+  EXPECT_TRUE(status);
+}
+
+// A sendmsg() failure is reported with its errno.
+TEST_F(SendMessageTest, Error) {
+  EXPECT_CALL(sender_, SendMessage(kSocketFd, &msg_, _))
+      .WillOnce(SetErrnoAndReturn(EBADF, -1));
+
+  auto status = SendMsgAll(&sender_, kSocket, &msg_);
+  ASSERT_FALSE(status);
+  EXPECT_EQ(EBADF, status.error());
+}
+
+// A failure in a follow-up Send() is reported with its errno.
+TEST_F(SendMessageTest, Error2) {
+  testing::InSequence sequence;
+  EXPECT_CALL(sender_, SendMessage(kSocketFd, &msg_, _)).WillOnce(Return(20));
+  EXPECT_CALL(sender_, Send(kSocketFd, IntToConstPtr(kPtr1 + 20), 80, _))
+      .WillOnce(SetErrnoAndReturn(EBADF, -1));
+
+  auto status = SendMsgAll(&sender_, kSocket, &msg_);
+  ASSERT_FALSE(status);
+  EXPECT_EQ(EBADF, status.error());
+}
+
+// RecvMsgAll
+
+// recvmsg() fills all 600 payload bytes at once.
+TEST_F(RecvMessageTest, Complete) {
+  EXPECT_CALL(receiver_,
+              ReceiveMessage(kSocketFd, &msg_, MSG_WAITALL | MSG_CMSG_CLOEXEC))
+      .WillOnce(Return(600));
+
+  auto status = RecvMsgAll(&receiver_, kSocket, &msg_);
+  EXPECT_TRUE(status);
+}
+
+// recvmsg() stops inside the first buffer. The rest of that buffer and every
+// following buffer must then be read with Receive(), strictly in order,
+// including a short read inside the second buffer.
+TEST_F(RecvMessageTest, Partial) {
+  testing::InSequence sequence;
+  EXPECT_CALL(receiver_, ReceiveMessage(kSocketFd, &msg_, _))
+      .WillOnce(Return(70));
+  EXPECT_CALL(receiver_, Receive(kSocketFd, IntToPtr(kPtr1 + 70), 30, _))
+      .WillOnce(Return(30));
+  EXPECT_CALL(receiver_, Receive(kSocketFd, IntToPtr(kPtr2), 200, _))
+      .WillOnce(Return(190));
+  EXPECT_CALL(receiver_, Receive(kSocketFd, IntToPtr(kPtr2 + 190), 10, _))
+      .WillOnce(Return(10));
+  EXPECT_CALL(receiver_, Receive(kSocketFd, IntToPtr(kPtr3), 300, _))
+      .WillOnce(Return(300));
+
+  auto status = RecvMsgAll(&receiver_, kSocket, &msg_);
+  EXPECT_TRUE(status);
+}
+
+// recvmsg() stops inside the last buffer; only that buffer's tail is read.
+TEST_F(RecvMessageTest, Partial2) {
+  testing::InSequence sequence;
+  EXPECT_CALL(receiver_, ReceiveMessage(kSocketFd, &msg_, _))
+      .WillOnce(Return(310));
+  EXPECT_CALL(receiver_, Receive(kSocketFd, IntToPtr(kPtr3 + 10), 290, _))
+      .WillOnce(Return(290));
+
+  auto status = RecvMsgAll(&receiver_, kSocket, &msg_);
+  EXPECT_TRUE(status);
+}
+
+// EINTR is retried for both the initial recvmsg() and the follow-up Receive().
+TEST_F(RecvMessageTest, Eintr) {
+  testing::InSequence sequence;
+  EXPECT_CALL(receiver_, ReceiveMessage(kSocketFd, &msg_, _))
+      .WillOnce(SetErrnoAndReturn(EINTR, -1))
+      .WillOnce(Return(70));
+  EXPECT_CALL(receiver_, Receive(kSocketFd, IntToPtr(kPtr1 + 70), 30, _))
+      .WillOnce(SetErrnoAndReturn(EINTR, -1))
+      .WillOnce(Return(30));
+  EXPECT_CALL(receiver_, Receive(kSocketFd, IntToPtr(kPtr2), 200, _))
+      .WillOnce(Return(200));
+  EXPECT_CALL(receiver_, Receive(kSocketFd, IntToPtr(kPtr3), 300, _))
+      .WillOnce(Return(300));
+
+  auto status = RecvMsgAll(&receiver_, kSocket, &msg_);
+  EXPECT_TRUE(status);
+}
+
+// A recvmsg() failure is reported with its errno.
+TEST_F(RecvMessageTest, Error) {
+  EXPECT_CALL(receiver_, ReceiveMessage(kSocketFd, &msg_, _))
+      .WillOnce(SetErrnoAndReturn(EBADF, -1));
+
+  auto status = RecvMsgAll(&receiver_, kSocket, &msg_);
+  ASSERT_FALSE(status);
+  EXPECT_EQ(EBADF, status.error());
+}
+
+// A failure in a follow-up Receive() is reported with its errno.
+TEST_F(RecvMessageTest, Error2) {
+  testing::InSequence sequence;
+  EXPECT_CALL(receiver_, ReceiveMessage(kSocketFd, &msg_, _))
+      .WillOnce(Return(20));
+  EXPECT_CALL(receiver_, Receive(kSocketFd, IntToPtr(kPtr1 + 20), 80, _))
+      .WillOnce(SetErrnoAndReturn(EBADF, -1));
+
+  auto status = RecvMsgAll(&receiver_, kSocket, &msg_);
+  ASSERT_FALSE(status);
+  EXPECT_EQ(EBADF, status.error());
+}
+
+// A zero-byte recvmsg() is reported as ESHUTDOWN without any follow-up reads.
+TEST_F(RecvMessageTest, Shutdown) {
+  EXPECT_CALL(receiver_, ReceiveMessage(kSocketFd, &msg_, _))
+      .WillOnce(Return(0));
+  EXPECT_CALL(receiver_, Receive(_, _, _, _)).Times(0);
+
+  auto status = RecvMsgAll(&receiver_, kSocket, &msg_);
+  ASSERT_FALSE(status);
+  EXPECT_EQ(ESHUTDOWN, status.error());
+}
+
+} // namespace
diff --git a/libs/vr/libpdx_uds/private/uds/ipc_helper.h b/libs/vr/libpdx_uds/private/uds/ipc_helper.h
index 82950a2..5b7e5ff 100644
--- a/libs/vr/libpdx_uds/private/uds/ipc_helper.h
+++ b/libs/vr/libpdx_uds/private/uds/ipc_helper.h
@@ -14,6 +14,38 @@
namespace pdx {
namespace uds {
+// Abstraction over the socket send()/sendmsg() calls. SendPayload and the
+// SendAll()/SendMsgAll() helpers write through this interface, so unit tests
+// can inject a mock instead of using a real socket.
+class SendInterface {
+ public:
+ // Same return convention as send(2): bytes written, or -1 with errno set.
+ virtual ssize_t Send(int socket_fd, const void* data, size_t size,
+ int flags) = 0;
+ // Same return convention as sendmsg(2).
+ virtual ssize_t SendMessage(int socket_fd, const msghdr* msg, int flags) = 0;
+
+ protected:
+ // Protected: implementations are not meant to be deleted through this
+ // interface.
+ virtual ~SendInterface() = default;
+};
+
+// Abstraction over the socket recv()/recvmsg() calls. ReceivePayload and the
+// RecvAll()/RecvMsgAll() helpers read through this interface, so unit tests
+// can inject a mock instead of using a real socket.
+class RecvInterface {
+ public:
+ // Same return convention as recv(2): bytes read, 0 when nothing more can be
+ // read (treated by the helpers as the socket being shut down), or -1 with
+ // errno set.
+ virtual ssize_t Receive(int socket_fd, void* data, size_t size,
+ int flags) = 0;
+ // Same return convention as recvmsg(2).
+ virtual ssize_t ReceiveMessage(int socket_fd, msghdr* msg, int flags) = 0;
+
+ protected:
+ // Protected: implementations are not meant to be deleted through this
+ // interface.
+ virtual ~RecvInterface() = default;
+};
+
+// Helpers that send/receive data through the abstract SendInterface and
+// RecvInterface, so that the underlying socket I/O can be mocked out in tests.
+// Each one retries calls that fail with EINTR, and keeps issuing further
+// Send()/Receive() calls until every requested byte has been transferred or
+// an error occurs.
+
+// Sends |size| bytes starting at |data|. On failure, returns the errno of the
+// failing Send() call.
+Status<void> SendAll(SendInterface* sender, const BorrowedHandle& socket_fd,
+ const void* data, size_t size);
+// Sends |msg| with a single SendMessage() call, which is the only call that
+// sees any ancillary data. Whatever part of the iovec payload it did not
+// transfer is then sent with SendAll().
+Status<void> SendMsgAll(SendInterface* sender, const BorrowedHandle& socket_fd,
+ const msghdr* msg);
+// Receives exactly |size| bytes into |data|. Fails with ESHUTDOWN if a
+// Receive() call returns 0 bytes.
+Status<void> RecvAll(RecvInterface* receiver, const BorrowedHandle& socket_fd,
+ void* data, size_t size);
+// Receives into |msg| with a single ReceiveMessage() call, then fills any
+// part of the iovec buffers it left empty with RecvAll(). Fails with
+// ESHUTDOWN if the initial call returns 0 bytes.
+Status<void> RecvMsgAll(RecvInterface* receiver,
+ const BorrowedHandle& socket_fd, msghdr* msg);
+
#define RETRY_EINTR(fnc_call) \
([&]() -> decltype(fnc_call) { \
decltype(fnc_call) result; \
@@ -25,6 +57,7 @@
class SendPayload : public MessageWriter, public OutputResourceMapper {
public:
+ SendPayload(SendInterface* sender = nullptr) : sender_{sender} {}
Status<void> Send(const BorrowedHandle& socket_fd);
Status<void> Send(const BorrowedHandle& socket_fd, const ucred* cred);
@@ -44,12 +77,14 @@
const RemoteChannelHandle& handle) override;
private:
+ SendInterface* sender_;
ByteBuffer buffer_;
std::vector<int> file_handles_;
};
class ReceivePayload : public MessageReader, public InputResourceMapper {
public:
+ ReceivePayload(RecvInterface* receiver = nullptr) : receiver_{receiver} {}
Status<void> Receive(const BorrowedHandle& socket_fd);
Status<void> Receive(const BorrowedHandle& socket_fd, ucred* cred);
@@ -64,6 +99,7 @@
LocalChannelHandle* handle) override;
private:
+ RecvInterface* receiver_;
ByteBuffer buffer_;
std::vector<LocalHandle> file_handles_;
size_t read_pos_{0};