libpdx_uds: Allow to create Endpoint/ClientChannel from a socket pair

This is important to enable Service/Client operation in unit tests.
Being able to create a pair of Unix domain sockets and construct both
Service and Client so that they can talk to each other without having
to create a physical socket file is convenient.

This change makes it possible to create an instance of Endpoint and
ClientChannel classes based just on a pair of sockets (Endpoint does
take another socket to simulate the main endpoint FD to accept incoming
connection on, but it is not used for this, only the shutdown events
are received from that main socket. Endpoint uses the channel FD to
perform actual communication with the client).

Bug: 37443070
Test: `libpdx_uds_tests` pass
Change-Id: Ifa1a9d03b97bd90282a04715c2105ad37a8de936
diff --git a/libs/vr/libpdx_uds/client_channel_factory.cpp b/libs/vr/libpdx_uds/client_channel_factory.cpp
index 9202cd5..850c6d3 100644
--- a/libs/vr/libpdx_uds/client_channel_factory.cpp
+++ b/libs/vr/libpdx_uds/client_channel_factory.cpp
@@ -39,32 +39,42 @@
 ClientChannelFactory::ClientChannelFactory(const std::string& endpoint_path)
     : endpoint_path_{GetEndpointPath(endpoint_path)} {}
 
+ClientChannelFactory::ClientChannelFactory(LocalHandle socket)
+    : socket_{std::move(socket)} {}
+
 std::unique_ptr<pdx::ClientChannelFactory> ClientChannelFactory::Create(
     const std::string& endpoint_path) {
   return std::unique_ptr<pdx::ClientChannelFactory>{
       new ClientChannelFactory{endpoint_path}};
 }
 
+std::unique_ptr<pdx::ClientChannelFactory> ClientChannelFactory::Create(
+    LocalHandle socket) {
+  return std::unique_ptr<pdx::ClientChannelFactory>{
+      new ClientChannelFactory{std::move(socket)}};
+}
+
 Status<std::unique_ptr<pdx::ClientChannel>> ClientChannelFactory::Connect(
     int64_t timeout_ms) const {
   Status<void> status;
 
-  LocalHandle socket_fd{socket(AF_UNIX, SOCK_STREAM, 0)};
-  if (!socket_fd) {
+  bool connected = socket_.IsValid();
+  if (!connected) {
+    socket_.Reset(socket(AF_UNIX, SOCK_STREAM, 0));
+    LOG_ALWAYS_FATAL_IF(
+        endpoint_path_.empty(),
+        "ClientChannelFactory::Connect: unspecified socket path");
+  }
+
+  if (!socket_) {
     ALOGE("ClientChannelFactory::Connect: socket error: %s", strerror(errno));
     return ErrorStatus(errno);
   }
 
-  sockaddr_un remote;
-  remote.sun_family = AF_UNIX;
-  strncpy(remote.sun_path, endpoint_path_.c_str(), sizeof(remote.sun_path));
-  remote.sun_path[sizeof(remote.sun_path) - 1] = '\0';
-
   bool use_timeout = (timeout_ms >= 0);
   auto now = steady_clock::now();
   auto time_end = now + std::chrono::milliseconds{timeout_ms};
 
-  bool connected = false;
   int max_eaccess = 5;  // Max number of times to retry when EACCES returned.
   while (!connected) {
     int64_t timeout = -1;
@@ -74,6 +84,10 @@
       if (timeout < 0)
         return ErrorStatus(ETIMEDOUT);
     }
+    sockaddr_un remote;
+    remote.sun_family = AF_UNIX;
+    strncpy(remote.sun_path, endpoint_path_.c_str(), sizeof(remote.sun_path));
+    remote.sun_path[sizeof(remote.sun_path) - 1] = '\0';
     ALOGD("ClientChannelFactory: Waiting for endpoint at %s", remote.sun_path);
     status = WaitForEndpoint(endpoint_path_, timeout);
     if (!status)
@@ -81,7 +95,7 @@
 
     ALOGD("ClientChannelFactory: Connecting to %s", remote.sun_path);
     int ret = RETRY_EINTR(connect(
-        socket_fd.Get(), reinterpret_cast<sockaddr*>(&remote), sizeof(remote)));
+        socket_.Get(), reinterpret_cast<sockaddr*>(&remote), sizeof(remote)));
     if (ret == -1) {
       ALOGD("ClientChannelFactory: Connect error %d: %s", errno,
             strerror(errno));
@@ -107,20 +121,20 @@
       }
     } else {
       connected = true;
+      ALOGD("ClientChannelFactory: Connected successfully to %s...",
+            remote.sun_path);
     }
     if (use_timeout)
       now = steady_clock::now();
   }  // while (!connected)
 
-  ALOGD("ClientChannelFactory: Connected successfully to %s...",
-        remote.sun_path);
   RequestHeader<BorrowedHandle> request;
   InitRequest(&request, opcodes::CHANNEL_OPEN, 0, 0, false);
-  status = SendData(socket_fd.Borrow(), request);
+  status = SendData(socket_.Borrow(), request);
   if (!status)
     return ErrorStatus(status.error());
   ResponseHeader<LocalHandle> response;
-  status = ReceiveData(socket_fd.Borrow(), &response);
+  status = ReceiveData(socket_.Borrow(), &response);
   if (!status)
     return ErrorStatus(status.error());
   int ref = response.ret_code;
@@ -129,7 +143,7 @@
 
   LocalHandle event_fd = std::move(response.file_descriptors[ref]);
   return ClientChannel::Create(ChannelManager::Get().CreateHandle(
-      std::move(socket_fd), std::move(event_fd)));
+      std::move(socket_), std::move(event_fd)));
 }
 
 }  // namespace uds
diff --git a/libs/vr/libpdx_uds/ipc_helper.cpp b/libs/vr/libpdx_uds/ipc_helper.cpp
index b675894..d75ce86 100644
--- a/libs/vr/libpdx_uds/ipc_helper.cpp
+++ b/libs/vr/libpdx_uds/ipc_helper.cpp
@@ -275,6 +275,7 @@
     return ret;
 
   if (preamble.magic != kMagicPreamble) {
+    ALOGE("ReceivePayload::Receive: Message header is invalid");
     ret.SetError(EIO);
     return ret;
   }
@@ -319,8 +320,10 @@
     cmsg = CMSG_NXTHDR(&msg, cmsg);
   }
 
-  if (cred && !cred_available)
+  if (cred && !cred_available) {
+    ALOGE("ReceivePayload::Receive: Failed to obtain message credentials");
     ret.SetError(EIO);
+  }
 
   return ret;
 }
diff --git a/libs/vr/libpdx_uds/private/uds/client_channel_factory.h b/libs/vr/libpdx_uds/private/uds/client_channel_factory.h
index 6f80d31..c43c5c7 100644
--- a/libs/vr/libpdx_uds/private/uds/client_channel_factory.h
+++ b/libs/vr/libpdx_uds/private/uds/client_channel_factory.h
@@ -13,6 +13,7 @@
  public:
   static std::unique_ptr<pdx::ClientChannelFactory> Create(
       const std::string& endpoint_path);
+  static std::unique_ptr<pdx::ClientChannelFactory> Create(LocalHandle socket);
 
   Status<std::unique_ptr<pdx::ClientChannel>> Connect(
       int64_t timeout_ms) const override;
@@ -22,7 +23,9 @@
 
  private:
   explicit ClientChannelFactory(const std::string& endpoint_path);
+  explicit ClientChannelFactory(LocalHandle socket);
 
+  mutable LocalHandle socket_;
   std::string endpoint_path_;
 };
 
diff --git a/libs/vr/libpdx_uds/private/uds/service_endpoint.h b/libs/vr/libpdx_uds/private/uds/service_endpoint.h
index f747abc..eb87827 100644
--- a/libs/vr/libpdx_uds/private/uds/service_endpoint.h
+++ b/libs/vr/libpdx_uds/private/uds/service_endpoint.h
@@ -97,6 +97,14 @@
   static std::unique_ptr<Endpoint> CreateAndBindSocket(
       const std::string& endpoint_path, bool blocking = kDefaultBlocking);
 
+  // Helper method to create an endpoint from an existing socket FD.
+  // Mostly helpful for tests.
+  static std::unique_ptr<Endpoint> CreateFromSocketFd(LocalHandle socket_fd);
+
+  // Test helper method to register a new channel identified by |channel_fd|
+  // socket file descriptor.
+  Status<void> RegisterNewChannelForTests(LocalHandle channel_fd);
+
   int epoll_fd() const { return epoll_fd_.Get(); }
 
  private:
@@ -109,6 +117,9 @@
   // This class must be instantiated using Create() static methods above.
   Endpoint(const std::string& endpoint_path, bool blocking,
            bool use_init_socket_fd = true);
+  Endpoint(LocalHandle socket_fd);
+
+  void Init(LocalHandle socket_fd);
 
   Endpoint(const Endpoint&) = delete;
   void operator=(const Endpoint&) = delete;
diff --git a/libs/vr/libpdx_uds/service_endpoint.cpp b/libs/vr/libpdx_uds/service_endpoint.cpp
index 65fd59f..6c92259 100644
--- a/libs/vr/libpdx_uds/service_endpoint.cpp
+++ b/libs/vr/libpdx_uds/service_endpoint.cpp
@@ -161,9 +161,16 @@
         bind(fd.Get(), reinterpret_cast<sockaddr*>(&local), sizeof(local));
     CHECK_EQ(ret, 0) << "Endpoint::Endpoint: bind error: " << strerror(errno);
   }
-  CHECK_EQ(listen(fd.Get(), kMaxBackLogForSocketListen), 0)
-      << "Endpoint::Endpoint: listen error: " << strerror(errno);
+  Init(std::move(fd));
+}
 
+Endpoint::Endpoint(LocalHandle socket_fd) { Init(std::move(socket_fd)); }
+
+void Endpoint::Init(LocalHandle socket_fd) {
+  if (socket_fd) {
+    CHECK_EQ(listen(socket_fd.Get(), kMaxBackLogForSocketListen), 0)
+        << "Endpoint::Endpoint: listen error: " << strerror(errno);
+  }
   cancel_event_fd_.Reset(eventfd(0, EFD_CLOEXEC | EFD_NONBLOCK));
   CHECK(cancel_event_fd_.IsValid())
       << "Endpoint::Endpoint: Failed to create event fd: " << strerror(errno);
@@ -172,24 +179,27 @@
   CHECK(epoll_fd_.IsValid())
       << "Endpoint::Endpoint: Failed to create epoll fd: " << strerror(errno);
 
-  epoll_event socket_event;
-  socket_event.events = EPOLLIN | EPOLLRDHUP | EPOLLONESHOT;
-  socket_event.data.fd = fd.Get();
+  if (socket_fd) {
+    epoll_event socket_event;
+    socket_event.events = EPOLLIN | EPOLLRDHUP | EPOLLONESHOT;
+    socket_event.data.fd = socket_fd.Get();
+    int ret = epoll_ctl(epoll_fd_.Get(), EPOLL_CTL_ADD, socket_fd.Get(),
+                        &socket_event);
+    CHECK_EQ(ret, 0)
+        << "Endpoint::Endpoint: Failed to add socket fd to epoll fd: "
+        << strerror(errno);
+  }
 
   epoll_event cancel_event;
   cancel_event.events = EPOLLIN;
   cancel_event.data.fd = cancel_event_fd_.Get();
 
-  int ret = epoll_ctl(epoll_fd_.Get(), EPOLL_CTL_ADD, fd.Get(), &socket_event);
-  CHECK_EQ(ret, 0)
-      << "Endpoint::Endpoint: Failed to add socket fd to epoll fd: "
-      << strerror(errno);
-  ret = epoll_ctl(epoll_fd_.Get(), EPOLL_CTL_ADD, cancel_event_fd_.Get(),
-                  &cancel_event);
+  int ret = epoll_ctl(epoll_fd_.Get(), EPOLL_CTL_ADD, cancel_event_fd_.Get(),
+                      &cancel_event);
   CHECK_EQ(ret, 0)
       << "Endpoint::Endpoint: Failed to add cancel event fd to epoll fd: "
       << strerror(errno);
-  socket_fd_ = std::move(fd);
+  socket_fd_ = std::move(socket_fd);
 }
 
 void* Endpoint::AllocateMessageState() { return new MessageState; }
@@ -199,6 +209,9 @@
 }
 
 Status<void> Endpoint::AcceptConnection(Message* message) {
+  if (!socket_fd_)
+    return ErrorStatus(EBADF);
+
   sockaddr_un remote;
   socklen_t addrlen = sizeof(remote);
   LocalHandle channel_fd{accept4(socket_fd_.Get(),
@@ -515,7 +528,7 @@
     return ErrorStatus{ESHUTDOWN};
   }
 
-  if (event.data.fd == socket_fd_.Get()) {
+  if (socket_fd_ && event.data.fd == socket_fd_.Get()) {
     auto status = AcceptConnection(message);
     if (!status)
       return status;
@@ -680,6 +693,23 @@
       new Endpoint(endpoint_path, blocking, false));
 }
 
+std::unique_ptr<Endpoint> Endpoint::CreateFromSocketFd(LocalHandle socket_fd) {
+  return std::unique_ptr<Endpoint>(new Endpoint(std::move(socket_fd)));
+}
+
+Status<void> Endpoint::RegisterNewChannelForTests(LocalHandle channel_fd) {
+  int optval = 1;
+  if (setsockopt(channel_fd.Get(), SOL_SOCKET, SO_PASSCRED, &optval,
+                 sizeof(optval)) == -1) {
+    ALOGE(
+        "Endpoint::RegisterNewChannelForTests: Failed to enable the receiving"
+        "of the credentials for channel %d: %s",
+        channel_fd.Get(), strerror(errno));
+    return ErrorStatus(errno);
+  }
+  return OnNewChannel(std::move(channel_fd));
+}
+
 }  // namespace uds
 }  // namespace pdx
 }  // namespace android