binder: RpcServer / RpcSession add API for certs.

The new getCertificate and addTrustedPeerCertificate APIs
forward to RpcTransportCtx::getCertificate and
RpcTransportCtx::addTrustedPeerCertificate.

For RpcSession, the peer (server) certificate is fixed when the
session is created via RpcSession::make().

Test: binderRpcTest
Bug: 195166979
Change-Id: I0d1bd93042895aeb3ab1de4fe6b9d90e73d0d116
diff --git a/libs/binder/RpcServer.cpp b/libs/binder/RpcServer.cpp
index a20445b..ad9ba96 100644
--- a/libs/binder/RpcServer.cpp
+++ b/libs/binder/RpcServer.cpp
@@ -39,8 +39,8 @@
 using base::ScopeGuard;
 using base::unique_fd;
 
-RpcServer::RpcServer(std::unique_ptr<RpcTransportCtxFactory> rpcTransportCtxFactory)
-      : mRpcTransportCtxFactory(std::move(rpcTransportCtxFactory)) {}
+// Takes ownership of the transport context built by RpcServer::make(); kept for the server's life.
+RpcServer::RpcServer(std::unique_ptr<RpcTransportCtx> ctx) : mCtx(std::move(ctx)) {}
 RpcServer::~RpcServer() {
     (void)shutdown();
 }
@@ -49,7 +48,13 @@
     // Default is without TLS.
     if (rpcTransportCtxFactory == nullptr)
         rpcTransportCtxFactory = RpcTransportCtxFactoryRaw::make();
-    return sp<RpcServer>::make(std::move(rpcTransportCtxFactory));
+    auto ctx = rpcTransportCtxFactory->newServerCtx();
+    if (ctx == nullptr) {
+        // Callers only see nullptr, so say which transport failed (this used to abort in join()).
+        ALOGE("Unable to create RpcTransportCtx with %s sockets", rpcTransportCtxFactory->toCString());
+        return nullptr;
+    }
+    return sp<RpcServer>::make(std::move(ctx));
 }
 
 void RpcServer::iUnderstandThisCodeIsExperimentalAndIWillNotUseItInProduction() {
@@ -138,6 +139,20 @@
     return ret;
 }
 
+std::string RpcServer::getCertificate(CertificateFormat format) {
+    // mLock serializes access to mCtx with addTrustedPeerCertificate().
+    std::lock_guard<std::mutex> _l(mLock);
+    return mCtx->getCertificate(format);
+}
+
+status_t RpcServer::addTrustedPeerCertificate(CertificateFormat format, std::string_view cert) {
+    std::lock_guard<std::mutex> _l(mLock);
+    // Refuse if the join thread is running or the shutdown trigger is still set up: either one
+    // means child threads may be running, and changing trusted certificates under them would race.
+    if (mJoinThreadRunning || mShutdownTrigger != nullptr) return INVALID_OPERATION;
+    return mCtx->addTrustedPeerCertificate(format, cert);
+}
+
 static void joinRpcServer(sp<RpcServer>&& thiz) {
     thiz->join();
 }
@@ -159,10 +174,6 @@
         mJoinThreadRunning = true;
         mShutdownTrigger = FdTrigger::make();
         LOG_ALWAYS_FATAL_IF(mShutdownTrigger == nullptr, "Cannot create join signaler");
-
-        mCtx = mRpcTransportCtxFactory->newServerCtx();
-        LOG_ALWAYS_FATAL_IF(mCtx == nullptr, "Unable to create RpcTransportCtx with %s sockets",
-                            mRpcTransportCtxFactory->toCString());
     }
 
     status_t status;
@@ -229,7 +240,6 @@
     LOG_RPC_DETAIL("Finished waiting on shutdown.");
 
     mShutdownTrigger = nullptr;
-    mCtx = nullptr;
     return true;
 }
 
diff --git a/libs/binder/RpcSession.cpp b/libs/binder/RpcSession.cpp
index 4c47005..22a7782 100644
--- a/libs/binder/RpcSession.cpp
+++ b/libs/binder/RpcSession.cpp
@@ -49,8 +49,7 @@
 
 using base::unique_fd;
 
-RpcSession::RpcSession(std::unique_ptr<RpcTransportCtxFactory> rpcTransportCtxFactory)
-      : mRpcTransportCtxFactory(std::move(rpcTransportCtxFactory)) {
+RpcSession::RpcSession(std::unique_ptr<RpcTransportCtx> ctx) : mCtx(std::move(ctx)) {
     LOG_RPC_DETAIL("RpcSession created %p", this);
 
     mState = std::make_unique<RpcState>();
@@ -63,11 +62,37 @@
                         "Should not be able to destroy a session with servers in use.");
 }
 
-sp<RpcSession> RpcSession::make(std::unique_ptr<RpcTransportCtxFactory> rpcTransportCtxFactory) {
+sp<RpcSession> RpcSession::make() {
     // Default is without TLS.
-    if (rpcTransportCtxFactory == nullptr)
-        rpcTransportCtxFactory = RpcTransportCtxFactoryRaw::make();
-    return sp<RpcSession>::make(std::move(rpcTransportCtxFactory));
+    return make(RpcTransportCtxFactoryRaw::make(), std::nullopt, std::nullopt);
+}
+
+sp<RpcSession> RpcSession::make(std::unique_ptr<RpcTransportCtxFactory> rpcTransportCtxFactory,
+                                std::optional<CertificateFormat> serverCertificateFormat,
+                                std::optional<std::string> serverCertificate) {
+    // A certificate without its format (or vice versa) is a programming error.
+    LOG_ALWAYS_FATAL_IF(serverCertificateFormat.has_value() != serverCertificate.has_value(),
+                        "serverCertificateFormat and serverCertificate must be set together");
+    // Like RpcServer::make, treat a null factory as the default (raw sockets, no TLS) instead of
+    // dereferencing it; the previous make() accepted nullptr as its default argument.
+    if (rpcTransportCtxFactory == nullptr)
+        rpcTransportCtxFactory = RpcTransportCtxFactoryRaw::make();
+    auto ctx = rpcTransportCtxFactory->newClientCtx();
+    if (ctx == nullptr) {
+        ALOGE("Unable to create client RpcTransportCtx with %s sockets",
+              rpcTransportCtxFactory->toCString());
+        return nullptr;
+    }
+    // The server certificate can only be trusted here, before the session owns the context.
+    if (serverCertificate.has_value()) {
+        status_t status =
+                ctx->addTrustedPeerCertificate(*serverCertificateFormat, *serverCertificate);
+        if (status != OK) {
+            ALOGE("Cannot add trusted server certificate: %s", statusToString(status).c_str());
+            return nullptr;
+        }
+    }
+    return sp<RpcSession>::make(std::move(ctx));
 }
 
 void RpcSession::setMaxThreads(size_t threads) {
@@ -155,12 +169,7 @@
         return -savedErrno;
     }
 
-    auto ctx = mRpcTransportCtxFactory->newClientCtx();
-    if (ctx == nullptr) {
-        ALOGE("Unable to create RpcTransportCtx for null debugging client");
-        return NO_MEMORY;
-    }
-    auto server = ctx->newTransport(std::move(serverFd), mShutdownTrigger.get());
+    auto server = mCtx->newTransport(std::move(serverFd), mShutdownTrigger.get());
     if (server == nullptr) {
         ALOGE("Unable to set up RpcTransport");
         return UNKNOWN_ERROR;
@@ -529,15 +538,9 @@
 status_t RpcSession::initAndAddConnection(unique_fd fd, const RpcAddress& sessionId,
                                           bool incoming) {
     LOG_ALWAYS_FATAL_IF(mShutdownTrigger == nullptr);
-    auto ctx = mRpcTransportCtxFactory->newClientCtx();
-    if (ctx == nullptr) {
-        ALOGE("Unable to create client RpcTransportCtx with %s sockets",
-              mRpcTransportCtxFactory->toCString());
-        return NO_MEMORY;
-    }
-    auto server = ctx->newTransport(std::move(fd), mShutdownTrigger.get());
+    auto server = mCtx->newTransport(std::move(fd), mShutdownTrigger.get());
     if (server == nullptr) {
-        ALOGE("Unable to set up RpcTransport in %s context", mRpcTransportCtxFactory->toCString());
+        ALOGE("%s: Unable to set up RpcTransport", __PRETTY_FUNCTION__);
         return UNKNOWN_ERROR;
     }
 
@@ -692,6 +695,12 @@
     return false;
 }
 
+std::string RpcSession::getCertificate(CertificateFormat certificateFormat) {
+    // mCtx is never replaced and peer certificates are only added in make(), so no session lock
+    // is taken here.
+    return mCtx->getCertificate(certificateFormat);
+}
+
 status_t RpcSession::ExclusiveConnection::find(const sp<RpcSession>& session, ConnectionUse use,
                                                ExclusiveConnection* connection) {
     connection->mSession = session;
diff --git a/libs/binder/include/binder/RpcServer.h b/libs/binder/include/binder/RpcServer.h
index bf3e7e0..d0e4e27 100644
--- a/libs/binder/include/binder/RpcServer.h
+++ b/libs/binder/include/binder/RpcServer.h
@@ -134,6 +134,17 @@
     sp<IBinder> getRootObject();
 
     /**
+     * See RpcTransportCtx::getCertificate. Thread-safe.
+     */
+    std::string getCertificate(CertificateFormat);
+
+    /**
+     * See RpcTransportCtx::addTrustedPeerCertificate.
+     * Thread-safe. Returns INVALID_OPERATION if called while the server is join()-ing.
+     */
+    status_t addTrustedPeerCertificate(CertificateFormat, std::string_view cert);
+
+    /**
      * Runs join() in a background thread. Immediately returns.
      */
     void start();
@@ -170,7 +181,7 @@
 
 private:
     friend sp<RpcServer>;
-    explicit RpcServer(std::unique_ptr<RpcTransportCtxFactory> rpcTransportCtxFactory);
+    explicit RpcServer(std::unique_ptr<RpcTransportCtx> ctx);
 
     void onSessionAllIncomingThreadsEnded(const sp<RpcSession>& session) override;
     void onSessionIncomingThreadEnded() override;
@@ -178,7 +189,7 @@
     static void establishConnection(sp<RpcServer>&& server, base::unique_fd clientFd);
     status_t setupSocketServer(const RpcSocketAddress& address);
 
-    const std::unique_ptr<RpcTransportCtxFactory> mRpcTransportCtxFactory;
+    const std::unique_ptr<RpcTransportCtx> mCtx;
     bool mAgreedExperimental = false;
     size_t mMaxThreads = 1;
     std::optional<uint32_t> mProtocolVersion;
@@ -193,7 +204,6 @@
     std::map<RpcAddress, sp<RpcSession>> mSessions;
     std::unique_ptr<FdTrigger> mShutdownTrigger;
     std::condition_variable mShutdownCv;
-    std::unique_ptr<RpcTransportCtx> mCtx;
 };
 
 } // namespace android
diff --git a/libs/binder/include/binder/RpcSession.h b/libs/binder/include/binder/RpcSession.h
index 6e6eb74..d92af0a 100644
--- a/libs/binder/include/binder/RpcSession.h
+++ b/libs/binder/include/binder/RpcSession.h
@@ -51,8 +51,15 @@
  */
 class RpcSession final : public virtual RefBase {
 public:
-    static sp<RpcSession> make(
-            std::unique_ptr<RpcTransportCtxFactory> rpcTransportCtxFactory = nullptr);
+    // Create an RpcSession with default configuration (raw sockets).
+    static sp<RpcSession> make();
+
+    // Create an RpcSession with the given configuration. |serverCertificateFormat| and
+    // |serverCertificate| must both be set or both be nullopt; if set, that server certificate is
+    // trusted. Returns nullptr if the transport context cannot be set up.
+    static sp<RpcSession> make(std::unique_ptr<RpcTransportCtxFactory> rpcTransportCtxFactory,
+                               std::optional<CertificateFormat> serverCertificateFormat,
+                               std::optional<std::string> serverCertificate);
 
     /**
      * Set the maximum number of threads allowed to be made (for things like callbacks).
@@ -125,6 +132,11 @@
     status_t getRemoteMaxThreads(size_t* maxThreads);
 
     /**
+     * See RpcTransportCtx::getCertificate
+     */
+    std::string getCertificate(CertificateFormat);
+
+    /**
      * Shuts down the service.
      *
      * For client sessions, wait can be true or false. For server sessions,
@@ -159,7 +171,7 @@
     friend sp<RpcSession>;
     friend RpcServer;
     friend RpcState;
-    explicit RpcSession(std::unique_ptr<RpcTransportCtxFactory> rpcTransportCtxFactory);
+    explicit RpcSession(std::unique_ptr<RpcTransportCtx> ctx);
 
     class EventListener : public virtual RefBase {
     public:
@@ -259,7 +271,7 @@
         bool mReentrant = false;
     };
 
-    const std::unique_ptr<RpcTransportCtxFactory> mRpcTransportCtxFactory;
+    const std::unique_ptr<RpcTransportCtx> mCtx;
 
     // On the other side of a session, for each of mOutgoingConnections here, there should
     // be one of mIncomingConnections on the other side (and vice versa).
diff --git a/libs/binder/tests/binderRpcTest.cpp b/libs/binder/tests/binderRpcTest.cpp
index 35db444..7c405d3 100644
--- a/libs/binder/tests/binderRpcTest.cpp
+++ b/libs/binder/tests/binderRpcTest.cpp
@@ -522,7 +522,8 @@
         status_t status;
 
         for (size_t i = 0; i < options.numSessions; i++) {
-            sp<RpcSession> session = RpcSession::make(newFactory(rpcSecurity));
+            sp<RpcSession> session =
+                    RpcSession::make(newFactory(rpcSecurity), std::nullopt, std::nullopt);
             session->setMaxThreads(options.numIncomingConnections);
 
             switch (socketType) {
@@ -1207,7 +1208,8 @@
     }
     server->start();
 
-    sp<RpcSession> session = RpcSession::make(RpcTransportCtxFactoryRaw::make());
+    sp<RpcSession> session =
+            RpcSession::make(RpcTransportCtxFactoryRaw::make(), std::nullopt, std::nullopt);
     status_t status = session->setupVsockClient(VMADDR_CID_LOCAL, vsockPort);
     while (!server->shutdown()) usleep(10000);
     ALOGE("Detected vsock loopback supported: %s", statusToString(status).c_str());
diff --git a/libs/binder/tests/parcel_fuzzer/random_parcel.cpp b/libs/binder/tests/parcel_fuzzer/random_parcel.cpp
index 8bf04cc..7fd9f6b 100644
--- a/libs/binder/tests/parcel_fuzzer/random_parcel.cpp
+++ b/libs/binder/tests/parcel_fuzzer/random_parcel.cpp
@@ -36,7 +36,8 @@
 
 void fillRandomParcel(Parcel* p, FuzzedDataProvider&& provider) {
     if (provider.ConsumeBool()) {
-        auto session = RpcSession::make(RpcTransportCtxFactoryRaw::make());
+        auto session =
+                RpcSession::make(RpcTransportCtxFactoryRaw::make(), std::nullopt, std::nullopt);
         CHECK_EQ(OK, session->addNullDebuggingClient());
         p->markForRpc(session);
         fillRandomParcelData(p, std::move(provider));