Merge changes from topic "binder-tls-trigger" am: 53195f9d63 am: dd326395b5 am: 1fbb5b6fc4 am: 5fccda1352

Original change: https://android-review.googlesource.com/c/platform/frameworks/native/+/1825993

Change-Id: I57a6b2267c01f34ce75fc40b9c840dbe2f1b5511
diff --git a/libs/binder/RpcTransportTls.cpp b/libs/binder/RpcTransportTls.cpp
index d40cfc8..63f9339 100644
--- a/libs/binder/RpcTransportTls.cpp
+++ b/libs/binder/RpcTransportTls.cpp
@@ -347,7 +347,7 @@
         ALOGE("%s: %s", __PRETTY_FUNCTION__, ret.error().message().c_str());
         return ret.error().code() == 0 ? UNKNOWN_ERROR : -ret.error().code();
     }
-    return OK;
+    return *ret ? -ECANCELED : OK;
 }
 
 status_t RpcTransportTls::interruptableWriteFully(FdTrigger* fdTrigger, const void* data,
diff --git a/libs/binder/tests/binderRpcTest.cpp b/libs/binder/tests/binderRpcTest.cpp
index a4e37ad..2fd63a3 100644
--- a/libs/binder/tests/binderRpcTest.cpp
+++ b/libs/binder/tests/binderRpcTest.cpp
@@ -53,6 +53,7 @@
 #include "RpcCertificateVerifierSimple.h"
 
 using namespace std::chrono_literals;
+using namespace std::placeholders;
 using testing::AssertionFailure;
 using testing::AssertionResult;
 using testing::AssertionSuccess;
@@ -1444,7 +1445,7 @@
                 PrintToString(certificateFormat);
     }
     void TearDown() override {
-        for (auto& server : mServers) server->shutdown();
+        for (auto& server : mServers) server->shutdownAndWait();
     }
 
     // A server that handles client socket connections.
@@ -1452,7 +1453,7 @@
     public:
         explicit Server() {}
         Server(Server&&) = default;
-        ~Server() { shutdown(); }
+        ~Server() { shutdownAndWait(); }
         [[nodiscard]] AssertionResult setUp() {
             auto [socketType, rpcSecurity, certificateFormat] = GetParam();
             auto rpcServer = RpcServer::make(newFactory(rpcSecurity));
@@ -1536,17 +1537,17 @@
             ASSERT_TRUE(acceptedFd.ok());
             auto serverTransport = mCtx->newTransport(std::move(acceptedFd), mFdTrigger.get());
             if (serverTransport == nullptr) return; // handshake failed
-            std::string message(kMessage);
-            ASSERT_EQ(OK,
-                      serverTransport->interruptableWriteFully(mFdTrigger.get(), message.data(),
-                                                               message.size()));
+            ASSERT_TRUE(mPostConnect(serverTransport.get(), mFdTrigger.get()));
         }
-        void shutdown() {
-            mFdTrigger->trigger();
-            if (mThread != nullptr) {
-                mThread->join();
-                mThread = nullptr;
-            }
+        void shutdownAndWait() {
+            shutdown();
+            join();
+        }
+        void shutdown() { mFdTrigger->trigger(); }
+
+        void setPostConnect(
+                std::function<AssertionResult(RpcTransport*, FdTrigger* fdTrigger)> fn) {
+            mPostConnect = std::move(fn);
         }
 
     private:
@@ -1558,6 +1559,26 @@
         std::shared_ptr<RpcCertificateVerifierSimple> mCertVerifier =
                 std::make_shared<RpcCertificateVerifierSimple>();
         bool mSetup = false;
+        // The function invoked after connection and handshake. By default, it is
+        // |defaultPostConnect| that sends |kMessage| to the client.
+        std::function<AssertionResult(RpcTransport*, FdTrigger* fdTrigger)> mPostConnect =
+                Server::defaultPostConnect;
+
+        void join() {
+            if (mThread != nullptr) {
+                mThread->join();
+                mThread = nullptr;
+            }
+        }
+
+        static AssertionResult defaultPostConnect(RpcTransport* serverTransport,
+                                                  FdTrigger* fdTrigger) {
+            std::string message(kMessage);
+            auto status = serverTransport->interruptableWriteFully(fdTrigger, message.data(),
+                                                                   message.size());
+            if (status != OK) return AssertionFailure() << statusToString(status);
+            return AssertionSuccess();
+        }
     };
 
     class Client {
@@ -1566,8 +1587,6 @@
         Client(Client&&) = default;
         [[nodiscard]] AssertionResult setUp() {
             auto [socketType, rpcSecurity, certificateFormat] = GetParam();
-            mFd = mConnectToServer();
-            if (!mFd.ok()) return AssertionFailure() << "Cannot connect to server";
             mFdTrigger = FdTrigger::make();
             mCtx = newFactory(rpcSecurity, mCertVerifier)->newClientCtx();
             if (mCtx == nullptr) return AssertionFailure() << "newClientCtx";
@@ -1577,24 +1596,35 @@
         std::shared_ptr<RpcCertificateVerifierSimple> getCertVerifier() const {
             return mCertVerifier;
         }
+        // connect() and do handshake
+        bool setUpTransport() {
+            mFd = mConnectToServer();
+            if (!mFd.ok()) return AssertionFailure() << "Cannot connect to server";
+            mClientTransport = mCtx->newTransport(std::move(mFd), mFdTrigger.get());
+            return mClientTransport != nullptr;
+        }
+        AssertionResult readMessage(const std::string& expectedMessage = kMessage) {
+            LOG_ALWAYS_FATAL_IF(mClientTransport == nullptr, "setUpTransport not called or failed");
+            std::string readMessage(expectedMessage.size(), '\0');
+            status_t readStatus =
+                    mClientTransport->interruptableReadFully(mFdTrigger.get(), readMessage.data(),
+                                                             readMessage.size());
+            if (readStatus != OK) {
+                return AssertionFailure() << statusToString(readStatus);
+            }
+            if (readMessage != expectedMessage) {
+                return AssertionFailure()
+                        << "Expected " << expectedMessage << ", actual " << readMessage;
+            }
+            return AssertionSuccess();
+        }
         void run(bool handshakeOk = true, bool readOk = true) {
-            auto clientTransport = mCtx->newTransport(std::move(mFd), mFdTrigger.get());
-            if (clientTransport == nullptr) {
+            if (!setUpTransport()) {
                 ASSERT_FALSE(handshakeOk) << "newTransport returns nullptr, but it shouldn't";
                 return;
             }
             ASSERT_TRUE(handshakeOk) << "newTransport does not return nullptr, but it should";
-            std::string expectedMessage(kMessage);
-            std::string readMessage(expectedMessage.size(), '\0');
-            status_t readStatus =
-                    clientTransport->interruptableReadFully(mFdTrigger.get(), readMessage.data(),
-                                                            readMessage.size());
-            if (readOk) {
-                ASSERT_EQ(OK, readStatus);
-                ASSERT_EQ(readMessage, expectedMessage);
-            } else {
-                ASSERT_NE(OK, readStatus);
-            }
+            ASSERT_EQ(readOk, readMessage());
         }
 
     private:
@@ -1604,6 +1634,7 @@
         std::unique_ptr<RpcTransportCtx> mCtx;
         std::shared_ptr<RpcCertificateVerifierSimple> mCertVerifier =
                 std::make_shared<RpcCertificateVerifierSimple>();
+        std::unique_ptr<RpcTransport> mClientTransport;
     };
 
     // Make A trust B.
@@ -1729,6 +1760,68 @@
     maliciousClient.run(true, readOk);
 }
 
+TEST_P(RpcTransportTest, Trigger) {
+    std::string msg2 = ", world!";
+    std::mutex writeMutex;
+    std::condition_variable writeCv;
+    bool shouldContinueWriting = false;
+    auto serverPostConnect = [&](RpcTransport* serverTransport, FdTrigger* fdTrigger) {
+        std::string message(kMessage);
+        auto status =
+                serverTransport->interruptableWriteFully(fdTrigger, message.data(), message.size());
+        if (status != OK) return AssertionFailure() << statusToString(status);
+
+        {
+            std::unique_lock<std::mutex> lock(writeMutex);
+            if (!writeCv.wait_for(lock, 3s, [&] { return shouldContinueWriting; })) {
+                return AssertionFailure() << "write barrier not cleared in time!";
+            }
+        }
+
+        status = serverTransport->interruptableWriteFully(fdTrigger, msg2.data(), msg2.size());
+        if (status != -ECANCELED)
+            return AssertionFailure() << "When FdTrigger is shut down, interruptableWriteFully "
+                                         "should return -ECANCELLED, but it is "
+                                      << statusToString(status);
+        return AssertionSuccess();
+    };
+
+    auto server = mServers.emplace_back(std::make_unique<Server>()).get();
+    ASSERT_TRUE(server->setUp());
+
+    // Set up client
+    Client client(server->getConnectToServerFn());
+    ASSERT_TRUE(client.setUp());
+
+    // Exchange keys
+    ASSERT_EQ(OK, trust(&client, server));
+    ASSERT_EQ(OK, trust(server, &client));
+
+    server->setPostConnect(serverPostConnect);
+
+    // Start server
+    server->start();
+    // connect() to server and do handshake
+    ASSERT_TRUE(client.setUpTransport());
+    // Read the first message. This confirms that the server finished the handshake and started
+    // handling the client fd. The server thread should now be blocked waiting on |writeCv|.
+    ASSERT_TRUE(client.readMessage(kMessage));
+    // Trigger server shutdown after server starts handling client FD. This ensures that the second
+    // write is on an FdTrigger that has been shut down.
+    server->shutdown();
+    // Let the server thread continue so that it attempts to write the second message.
+    {
+        std::unique_lock<std::mutex> lock(writeMutex);
+        shouldContinueWriting = true;
+        lock.unlock();
+        writeCv.notify_all();
+    }
+    // After this line, the server thread unblocks and attempts to write the second message, but
+    // shutdown is triggered, so the write should fail with -ECANCELED. See |serverPostConnect|.
+    // On the client side, the second read fails with DEAD_OBJECT.
+    ASSERT_FALSE(client.readMessage(msg2));
+}
+
 std::vector<RpcCertificateFormat> testRpcCertificateFormats() {
     return {
             RpcCertificateFormat::PEM,