libpdx_uds: Serialize access to connection socket between threads

Added a mutex so that only one client thread at a time can perform the
atomic send-request/receive-response sequence.

Also added a unit test that performs multiple parallel client requests
to the same service to ensure it can handle multithreaded access
correctly.

Bug: 37443070
Test: `libpdx_uds_tests` pass
Change-Id: Ica516f7806f9146fb530b5cb371d2ee89146fed7
diff --git a/libs/vr/libpdx_uds/Android.bp b/libs/vr/libpdx_uds/Android.bp
index f2bcc0c..cfc2022 100644
--- a/libs/vr/libpdx_uds/Android.bp
+++ b/libs/vr/libpdx_uds/Android.bp
@@ -35,6 +35,7 @@
         "-Werror",
     ],
     srcs: [
+        "client_channel_tests.cpp",
         "ipc_helper_tests.cpp",
         "remote_method_tests.cpp",
         "service_framework_tests.cpp",
diff --git a/libs/vr/libpdx_uds/client_channel.cpp b/libs/vr/libpdx_uds/client_channel.cpp
index 924335f..9d91617 100644
--- a/libs/vr/libpdx_uds/client_channel.cpp
+++ b/libs/vr/libpdx_uds/client_channel.cpp
@@ -156,6 +156,7 @@
 
 Status<void> ClientChannel::SendImpulse(int opcode, const void* buffer,
                                         size_t length) {
+  std::unique_lock<std::mutex> lock(socket_mutex_);
   Status<void> status;
   android::pdx::uds::RequestHeader<BorrowedHandle> request;
   if (length > request.impulse_payload.size() ||
@@ -174,6 +175,7 @@
                                           size_t send_count,
                                           const iovec* receive_vector,
                                           size_t receive_count) {
+  std::unique_lock<std::mutex> lock(socket_mutex_);
   Status<int> result;
   if ((send_vector == nullptr && send_count != 0) ||
       (receive_vector == nullptr && receive_count != 0)) {
diff --git a/libs/vr/libpdx_uds/client_channel_tests.cpp b/libs/vr/libpdx_uds/client_channel_tests.cpp
new file mode 100644
index 0000000..7c3c68a
--- /dev/null
+++ b/libs/vr/libpdx_uds/client_channel_tests.cpp
@@ -0,0 +1,162 @@
+#include <uds/client_channel.h>
+
+#include <sys/socket.h>
+
+#include <algorithm>
+#include <limits>
+#include <random>
+#include <thread>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+
+#include <pdx/client.h>
+#include <pdx/rpc/remote_method.h>
+#include <pdx/service.h>
+
+#include <uds/client_channel_factory.h>
+#include <uds/service_endpoint.h>
+
+using testing::Return;
+using testing::_;
+
+using android::pdx::ClientBase;
+using android::pdx::LocalChannelHandle;
+using android::pdx::LocalHandle;
+using android::pdx::Message;
+using android::pdx::ServiceBase;
+using android::pdx::ServiceDispatcher;
+using android::pdx::Status;
+using android::pdx::rpc::DispatchRemoteMethod;
+using android::pdx::uds::ClientChannel;
+using android::pdx::uds::ClientChannelFactory;
+using android::pdx::uds::Endpoint;
+
+namespace {
+
+struct TestProtocol {
+  using DataType = int8_t;  // Element type of the vectors handed to Sum.
+  enum {
+    kOpSum = 0,  // Opcode of the Sum method.
+  };
+  PDX_REMOTE_METHOD(Sum, kOpSum, int64_t(const std::vector<DataType>&));
+};
+
+// Service under test: answers TestProtocol::Sum, defers every other opcode.
+class TestService : public ServiceBase<TestService> {
+ public:
+  TestService(std::unique_ptr<Endpoint> endpoint)
+      : ServiceBase{"TestService", std::move(endpoint)} {}
+  // Routes kOpSum to OnSum; anything else goes to Service::HandleMessage.
+  Status<void> HandleMessage(Message& message) override {
+    switch (message.GetOp()) {
+      case TestProtocol::kOpSum:
+        DispatchRemoteMethod<TestProtocol::Sum>(*this, &TestService::OnSum,
+                                                message);
+        return {};
+      default:
+        return Service::HandleMessage(message);
+    }
+  }
+  // Adds up |values| in 64 bits, starting from zero.
+  int64_t OnSum(Message& /*message*/,
+                const std::vector<TestProtocol::DataType>& values) {
+    return std::accumulate(values.begin(), values.end(), int64_t{0});
+  }
+};
+
+class TestClient : public ClientBase<TestClient> {
+ public:
+  using ClientBase::ClientBase;
+  // Calls the remote Sum on |data|; yields -1 when the returned status is bad.
+  int64_t Sum(const std::vector<TestProtocol::DataType>& data) {
+    auto reply = InvokeRemoteMethod<TestProtocol::Sum>(data);
+    return !reply ? -1 : reply.get();
+  }
+};
+
+class TestServiceRunner {
+ public:
+  // Serves |channel_socket| from a TestService on its own dispatch thread.
+  TestServiceRunner(LocalHandle channel_socket) {
+    auto endpoint = Endpoint::CreateFromSocketFd(LocalHandle{});
+    endpoint->RegisterNewChannelForTests(std::move(channel_socket));
+    service_ = TestService::Create(std::move(endpoint));
+    dispatcher_ = android::pdx::uds::ServiceDispatcher::Create();
+    dispatcher_->AddService(service_);
+    dispatch_thread_ = std::thread(&ServiceDispatcher::EnterDispatchLoop,
+                                   dispatcher_.get());
+  }
+  // Cancels the dispatcher, joins the loop thread, then removes the service.
+  ~TestServiceRunner() {
+    dispatcher_->SetCanceled(true);
+    dispatch_thread_.join();
+    dispatcher_->RemoveService(service_);
+  }
+ private:
+  std::shared_ptr<TestService> service_;
+  std::unique_ptr<ServiceDispatcher> dispatcher_;
+  std::thread dispatch_thread_;
+};
+
+class ClientChannelTest : public testing::Test {
+ public:
+  // Wires a TestService and a TestClient together over a fresh socket pair.
+  void SetUp() override {
+    int fds[2] = {};
+    ASSERT_EQ(0, socketpair(AF_UNIX, SOCK_STREAM | SOCK_CLOEXEC, 0, fds));
+    LocalHandle service_end{fds[0]};
+    LocalHandle client_end{fds[1]};
+    service_runner_.reset(new TestServiceRunner{std::move(service_end)});
+    auto factory = ClientChannelFactory::Create(std::move(client_end));
+    auto channel = factory->Connect(android::pdx::Client::kInfiniteTimeout);
+    ASSERT_TRUE(channel);
+    client_ = TestClient::Create(channel.take());
+  }
+
+  // Stops the service side first, then drops the fixture's client reference.
+  void TearDown() override {
+    service_runner_.reset();
+    client_.reset();
+  }
+
+ protected:
+  std::unique_ptr<TestServiceRunner> service_runner_;
+  std::shared_ptr<TestClient> client_;
+};
+
+// Checks that concurrent RPCs over one shared client channel stay intact.
+TEST_F(ClientChannelTest, MultithreadedClient) {
+  constexpr int kNumTestThreads = 8;
+  constexpr size_t kDataSize = 1000;  // Try to keep RPC buffer size below 4K.
+
+  // IntType may not be a char type (DataType is int8_t): draw ints and narrow.
+  using Limits = std::numeric_limits<TestProtocol::DataType>;
+  std::mt19937 gen{std::random_device{}()};
+  std::uniform_int_distribution<int> dist{Limits::min(), Limits::max()};
+
+  auto worker = [](std::shared_ptr<TestClient> client,
+                   std::vector<TestProtocol::DataType> data) {
+    constexpr int kMaxIterations = 500;
+    int64_t expected = std::accumulate(data.begin(), data.end(), int64_t{0});
+    for (int i = 0; i < kMaxIterations; i++) {
+      ASSERT_EQ(expected, client->Sum(data));
+    }
+  };
+
+  // Start client threads.
+  std::vector<TestProtocol::DataType> data(kDataSize);
+  std::vector<std::thread> threads;
+  for (int i = 0; i < kNumTestThreads; i++) {
+    std::generate(data.begin(), data.end(), [&dist, &gen]() {
+      return static_cast<TestProtocol::DataType>(dist(gen));
+    });
+    threads.emplace_back(worker, client_, data);
+  }
+
+  // Wait for threads to finish.
+  for (auto& thread : threads)
+    thread.join();
+}
+
+}  // namespace
diff --git a/libs/vr/libpdx_uds/private/uds/client_channel.h b/libs/vr/libpdx_uds/private/uds/client_channel.h
index 45f6473..8f607f5 100644
--- a/libs/vr/libpdx_uds/private/uds/client_channel.h
+++ b/libs/vr/libpdx_uds/private/uds/client_channel.h
@@ -3,6 +3,8 @@
 
 #include <pdx/client_channel.h>
 
+#include <mutex>
+
 #include <uds/channel_event_set.h>
 #include <uds/channel_manager.h>
 #include <uds/service_endpoint.h>
@@ -73,6 +75,7 @@
 
   LocalChannelHandle channel_handle_;
   ChannelManager::ChannelData* channel_data_;
+  std::mutex socket_mutex_;
 };
 
 }  // namespace uds