binder: Use RpcTransport

- after accept() / connect(), call sslAccept() /
  sslConnect(), respectively.
- replace ::send() / ::recv() with RpcTransport::
  send() / recv() / peek() accordingly.

Also refactor binderRpcTest to prepare for TLS implementation.

Test: TH
Test: binderRpcTest
Bug: 190868302

Change-Id: I809345c59a467cd219ebcec7a9db3a3b7776a601
diff --git a/libs/binder/RpcSession.cpp b/libs/binder/RpcSession.cpp
index 254b99c..6a22913 100644
--- a/libs/binder/RpcSession.cpp
+++ b/libs/binder/RpcSession.cpp
@@ -30,6 +30,7 @@
 #include <android_runtime/vm.h>
 #include <binder/Parcel.h>
 #include <binder/RpcServer.h>
+#include <binder/RpcTransportRaw.h>
 #include <binder/Stability.h>
 #include <jni.h>
 #include <utils/String8.h>
@@ -46,7 +47,8 @@
 
 using base::unique_fd;
 
-RpcSession::RpcSession() {
+RpcSession::RpcSession(std::unique_ptr<RpcTransportCtxFactory> rpcTransportCtxFactory)
+      : mRpcTransportCtxFactory(std::move(rpcTransportCtxFactory)) {
     LOG_RPC_DETAIL("RpcSession created %p", this);
 
     mState = std::make_unique<RpcState>();
@@ -59,8 +61,11 @@
                         "Should not be able to destroy a session with servers in use.");
 }
 
-sp<RpcSession> RpcSession::make() {
-    return sp<RpcSession>::make();
+sp<RpcSession> RpcSession::make(std::unique_ptr<RpcTransportCtxFactory> rpcTransportCtxFactory) {
+    // Default is without TLS.
+    if (rpcTransportCtxFactory == nullptr)
+        rpcTransportCtxFactory = RpcTransportCtxFactoryRaw::make();
+    return sp<RpcSession>::make(std::move(rpcTransportCtxFactory));
 }
 
 void RpcSession::setMaxThreads(size_t threads) {
@@ -122,6 +127,7 @@
 }
 
 bool RpcSession::addNullDebuggingClient() {
+    // Note: only works on raw sockets.
     unique_fd serverFd(TEMP_FAILURE_RETRY(open("/dev/null", O_WRONLY | O_CLOEXEC)));
 
     if (serverFd == -1) {
@@ -129,7 +135,17 @@
         return false;
     }
 
-    return addOutgoingConnection(std::move(serverFd), false);
+    auto ctx = mRpcTransportCtxFactory->newClientCtx();
+    if (ctx == nullptr) {
+        ALOGE("Unable to create RpcTransportCtx for null debugging client");
+        return false;
+    }
+    auto server = ctx->newTransport(std::move(serverFd));
+    if (server == nullptr) {
+        ALOGE("Unable to set up RpcTransport");
+        return false;
+    }
+    return addOutgoingConnection(std::move(server), false);
 }
 
 sp<IBinder> RpcSession::getRootObject() {
@@ -205,6 +221,10 @@
     return mWrite == -1;
 }
 
+status_t RpcSession::FdTrigger::triggerablePoll(RpcTransport* rpcTransport, int16_t event) {
+    return triggerablePoll(rpcTransport->pollSocket(), event);
+}
+
 status_t RpcSession::FdTrigger::triggerablePoll(base::borrowed_fd fd, int16_t event) {
     while (true) {
         pollfd pfd[]{{.fd = fd.get(), .events = static_cast<int16_t>(event), .revents = 0},
@@ -223,28 +243,30 @@
     }
 }
 
-status_t RpcSession::FdTrigger::interruptableWriteFully(base::borrowed_fd fd, const void* data,
-                                                        size_t size) {
+status_t RpcSession::FdTrigger::interruptableWriteFully(RpcTransport* rpcTransport,
+                                                        const void* data, size_t size) {
     const uint8_t* buffer = reinterpret_cast<const uint8_t*>(data);
     const uint8_t* end = buffer + size;
 
     MAYBE_WAIT_IN_FLAKE_MODE;
 
     status_t status;
-    while ((status = triggerablePoll(fd, POLLOUT)) == OK) {
-        ssize_t writeSize = TEMP_FAILURE_RETRY(send(fd.get(), buffer, end - buffer, MSG_NOSIGNAL));
-        if (writeSize == 0) return DEAD_OBJECT;
-
-        if (writeSize < 0) {
-            return -errno;
+    while ((status = triggerablePoll(rpcTransport, POLLOUT)) == OK) {
+        auto writeSize = rpcTransport->send(buffer, end - buffer);
+        if (!writeSize.ok()) {
+            LOG_RPC_DETAIL("RpcTransport::send(): %s", writeSize.error().message().c_str());
+            return writeSize.error().code() == 0 ? UNKNOWN_ERROR : -writeSize.error().code();
         }
-        buffer += writeSize;
+
+        if (*writeSize == 0) return DEAD_OBJECT;
+
+        buffer += *writeSize;
         if (buffer == end) return OK;
     }
     return status;
 }
 
-status_t RpcSession::FdTrigger::interruptableReadFully(base::borrowed_fd fd, void* data,
+status_t RpcSession::FdTrigger::interruptableReadFully(RpcTransport* rpcTransport, void* data,
                                                        size_t size) {
     uint8_t* buffer = reinterpret_cast<uint8_t*>(data);
     uint8_t* end = buffer + size;
@@ -252,14 +274,16 @@
     MAYBE_WAIT_IN_FLAKE_MODE;
 
     status_t status;
-    while ((status = triggerablePoll(fd, POLLIN)) == OK) {
-        ssize_t readSize = TEMP_FAILURE_RETRY(recv(fd.get(), buffer, end - buffer, MSG_NOSIGNAL));
-        if (readSize == 0) return DEAD_OBJECT; // EOF
-
-        if (readSize < 0) {
-            return -errno;
+    while ((status = triggerablePoll(rpcTransport, POLLIN)) == OK) {
+        auto readSize = rpcTransport->recv(buffer, end - buffer);
+        if (!readSize.ok()) {
+            LOG_RPC_DETAIL("RpcTransport::recv(): %s", readSize.error().message().c_str());
+            return readSize.error().code() == 0 ? UNKNOWN_ERROR : -readSize.error().code();
         }
-        buffer += readSize;
+
+        if (*readSize == 0) return DEAD_OBJECT; // EOF
+
+        buffer += *readSize;
         if (buffer == end) return OK;
     }
     return status;
@@ -312,10 +336,11 @@
     }
 }
 
-RpcSession::PreJoinSetupResult RpcSession::preJoinSetup(base::unique_fd fd) {
+RpcSession::PreJoinSetupResult RpcSession::preJoinSetup(
+        std::unique_ptr<RpcTransport> rpcTransport) {
     // must be registered to allow arbitrary client code executing commands to
     // be able to do nested calls (we can't only read from it)
-    sp<RpcConnection> connection = assignIncomingConnectionToThisThread(std::move(fd));
+    sp<RpcConnection> connection = assignIncomingConnectionToThisThread(std::move(rpcTransport));
 
     status_t status;
 
@@ -520,6 +545,22 @@
                   strerror(savedErrno));
             return false;
         }
+        LOG_RPC_DETAIL("Socket at %s client with fd %d", addr.toString().c_str(), serverFd.get());
+
+        auto ctx = mRpcTransportCtxFactory->newClientCtx();
+        if (ctx == nullptr) {
+            ALOGE("Unable to create client RpcTransportCtx with %s sockets",
+                  mRpcTransportCtxFactory->toCString());
+            return false;
+        }
+        auto server = ctx->newTransport(std::move(serverFd));
+        if (server == nullptr) {
+            ALOGE("Unable to set up RpcTransport for %s", addr.toString().c_str());
+            return false;
+        }
+
+        LOG_RPC_DETAIL("Socket at %s client with RpcTransport %p", addr.toString().c_str(),
+                       server.get());
 
         RpcConnectionHeader header{
                 .version = mProtocolVersion.value_or(RPC_WIRE_PROTOCOL_VERSION),
@@ -529,19 +570,24 @@
 
         if (incoming) header.options |= RPC_CONNECTION_OPTION_INCOMING;
 
-        if (sizeof(header) != TEMP_FAILURE_RETRY(write(serverFd.get(), &header, sizeof(header)))) {
-            int savedErrno = errno;
+        auto sentHeader = server->send(&header, sizeof(header));
+        if (!sentHeader.ok()) {
             ALOGE("Could not write connection header to socket at %s: %s", addr.toString().c_str(),
-                  strerror(savedErrno));
+                  sentHeader.error().message().c_str());
+            return false;
+        }
+        if (*sentHeader != sizeof(header)) {
+            ALOGE("Could not write connection header to socket at %s: sent %zd bytes, expected %zd",
+                  addr.toString().c_str(), *sentHeader, sizeof(header));
             return false;
         }
 
-        LOG_RPC_DETAIL("Socket at %s client with fd %d", addr.toString().c_str(), serverFd.get());
+        LOG_RPC_DETAIL("Socket at %s client: header sent", addr.toString().c_str());
 
         if (incoming) {
-            return addIncomingConnection(std::move(serverFd));
+            return addIncomingConnection(std::move(server));
         } else {
-            return addOutgoingConnection(std::move(serverFd), true);
+            return addOutgoingConnection(std::move(server), true);
         }
     }
 
@@ -549,7 +595,7 @@
     return false;
 }
 
-bool RpcSession::addIncomingConnection(unique_fd fd) {
+bool RpcSession::addIncomingConnection(std::unique_ptr<RpcTransport> rpcTransport) {
     std::mutex mutex;
     std::condition_variable joinCv;
     std::unique_lock<std::mutex> lock(mutex);
@@ -558,13 +604,13 @@
     bool ownershipTransferred = false;
     thread = std::thread([&]() {
         std::unique_lock<std::mutex> threadLock(mutex);
-        unique_fd movedFd = std::move(fd);
+        std::unique_ptr<RpcTransport> movedRpcTransport = std::move(rpcTransport);
         // NOLINTNEXTLINE(performance-unnecessary-copy-initialization)
         sp<RpcSession> session = thiz;
         session->preJoinThreadOwnership(std::move(thread));
 
         // only continue once we have a response or the connection fails
-        auto setupResult = session->preJoinSetup(std::move(movedFd));
+        auto setupResult = session->preJoinSetup(std::move(movedRpcTransport));
 
         ownershipTransferred = true;
         threadLock.unlock();
@@ -578,7 +624,7 @@
     return true;
 }
 
-bool RpcSession::addOutgoingConnection(unique_fd fd, bool init) {
+bool RpcSession::addOutgoingConnection(std::unique_ptr<RpcTransport> rpcTransport, bool init) {
     sp<RpcConnection> connection = sp<RpcConnection>::make();
     {
         std::lock_guard<std::mutex> _l(mMutex);
@@ -591,7 +637,7 @@
             if (mShutdownTrigger == nullptr) return false;
         }
 
-        connection->fd = std::move(fd);
+        connection->rpcTransport = std::move(rpcTransport);
         connection->exclusiveTid = gettid();
         mOutgoingConnections.push_back(connection);
     }
@@ -626,7 +672,8 @@
     return true;
 }
 
-sp<RpcSession::RpcConnection> RpcSession::assignIncomingConnectionToThisThread(unique_fd fd) {
+sp<RpcSession::RpcConnection> RpcSession::assignIncomingConnectionToThisThread(
+        std::unique_ptr<RpcTransport> rpcTransport) {
     std::lock_guard<std::mutex> _l(mMutex);
 
     // Don't accept any more connections, some have shutdown. Usually this
@@ -638,7 +685,7 @@
     }
 
     sp<RpcConnection> session = sp<RpcConnection>::make();
-    session->fd = std::move(fd);
+    session->rpcTransport = std::move(rpcTransport);
     session->exclusiveTid = gettid();
 
     mIncomingConnections.push_back(session);