snapuserd: Refactor daemon/server.

This is in preparation for moving to a traditional client/server model
where clients stay connected and the server multiplexes multiple
connections.

Client has been renamed to DmUserClient to differentiate it from local
socket clients.

poll() responsibilities have been moved into SnapuserdServer. In
addition, the server now tracks all open clients and polls them
together with the listen socket.

SnapuserDaemon is now only responsible for signal masking. These two
classes can probably be merged together - I didn't do that here because
the patch was already large.

Bug: 168554689
Test: manual test
Change-Id: Ibc06f6287d49e832a8e25dd936ec07747a1b0555
diff --git a/fs_mgr/libsnapshot/Android.bp b/fs_mgr/libsnapshot/Android.bp
index b239f31..84c93b1 100644
--- a/fs_mgr/libsnapshot/Android.bp
+++ b/fs_mgr/libsnapshot/Android.bp
@@ -402,7 +402,7 @@
     ],
     srcs: [
 	"snapuserd_server.cpp",
-        "snapuserd.cpp",
+    "snapuserd.cpp",
 	"snapuserd_daemon.cpp",
     ],
 
@@ -554,7 +554,7 @@
         "libbrotli",
         "libgtest",
         "libsnapshot_cow",
-	"libsnapshot_snapuserd",
+        "libsnapshot_snapuserd",
         "libcutils_sockets",
         "libz",
 	"libdm",
diff --git a/fs_mgr/libsnapshot/include/libsnapshot/snapuserd_client.h b/fs_mgr/libsnapshot/include/libsnapshot/snapuserd_client.h
index ab2149e..d6713b8 100644
--- a/fs_mgr/libsnapshot/include/libsnapshot/snapuserd_client.h
+++ b/fs_mgr/libsnapshot/include/libsnapshot/snapuserd_client.h
@@ -26,6 +26,9 @@
 static constexpr uint32_t PACKET_SIZE = 512;
 static constexpr uint32_t MAX_CONNECT_RETRY_COUNT = 10;
 
+static constexpr char kSnapuserdSocketFirstStage[] = "snapuserd_first_stage";
+static constexpr char kSnapuserdSocket[] = "snapuserd";
+
 class SnapuserdClient {
   private:
     int sockfd_ = 0;
diff --git a/fs_mgr/libsnapshot/include/libsnapshot/snapuserd_daemon.h b/fs_mgr/libsnapshot/include/libsnapshot/snapuserd_daemon.h
index 94542d7..c6779b8 100644
--- a/fs_mgr/libsnapshot/include/libsnapshot/snapuserd_daemon.h
+++ b/fs_mgr/libsnapshot/include/libsnapshot/snapuserd_daemon.h
@@ -25,22 +25,21 @@
     // The Daemon class is a singleton to avoid
     // instantiating more than once
   public:
+    Daemon() {}
+
     static Daemon& Instance() {
         static Daemon instance;
         return instance;
     }
 
-    int StartServer(std::string socketname);
-    bool IsRunning();
+    bool StartServer(const std::string& socketname);
     void Run();
+    void Interrupt();
 
   private:
-    bool is_running_;
-    std::unique_ptr<struct pollfd> poll_fd_;
     // Signal mask used with ppoll()
     sigset_t signal_mask_;
 
-    Daemon();
     Daemon(Daemon const&) = delete;
     void operator=(Daemon const&) = delete;
 
diff --git a/fs_mgr/libsnapshot/include/libsnapshot/snapuserd_server.h b/fs_mgr/libsnapshot/include/libsnapshot/snapuserd_server.h
index a1ebd3a..357acac 100644
--- a/fs_mgr/libsnapshot/include/libsnapshot/snapuserd_server.h
+++ b/fs_mgr/libsnapshot/include/libsnapshot/snapuserd_server.h
@@ -14,6 +14,8 @@
 
 #pragma once
 
+#include <poll.h>
+
 #include <cstdio>
 #include <cstring>
 #include <functional>
@@ -34,12 +36,11 @@
 enum class DaemonOperations {
     START,
     QUERY,
-    TERMINATING,
     STOP,
     INVALID,
 };
 
-class Client {
+class DmUserHandler {
   private:
     std::unique_ptr<std::thread> threadHandler_;
 
@@ -77,7 +78,15 @@
   private:
     android::base::unique_fd sockfd_;
     bool terminating_;
-    std::vector<std::unique_ptr<Client>> clients_vec_;
+    std::vector<std::unique_ptr<DmUserHandler>> dm_users_;
+    std::vector<struct pollfd> watched_fds_;
+
+    void AddWatchedFd(android::base::borrowed_fd fd);
+    void AcceptClient();
+    bool HandleClient(android::base::borrowed_fd fd, int revents);
+    bool Recv(android::base::borrowed_fd fd, std::string* data);
+    bool Sendmsg(android::base::borrowed_fd fd, const std::string& msg);
+    bool Receivemsg(android::base::borrowed_fd fd, const std::string& msg);
 
     void ThreadStart(std::string cow_device, std::string backing_device,
                      std::string control_device) override;
@@ -92,13 +101,11 @@
 
   public:
     SnapuserdServer() { terminating_ = false; }
+    ~SnapuserdServer();
 
-    int Start(std::string socketname);
-    int AcceptClient();
-    int Receivemsg(int fd);
-    int Sendmsg(int fd, char* msg, size_t len);
-    std::string Recvmsg(int fd, int* ret);
-    android::base::borrowed_fd GetSocketFd() { return sockfd_; }
+    bool Start(const std::string& socketname);
+    bool Run();
+    void Interrupt();
 };
 
 }  // namespace snapshot
diff --git a/fs_mgr/libsnapshot/snapuserd.cpp b/fs_mgr/libsnapshot/snapuserd.cpp
index 62ef1b0..440d0ce 100644
--- a/fs_mgr/libsnapshot/snapuserd.cpp
+++ b/fs_mgr/libsnapshot/snapuserd.cpp
@@ -17,6 +17,7 @@
 #include <csignal>
 
 #include <libsnapshot/snapuserd.h>
+#include <libsnapshot/snapuserd_client.h>
 #include <libsnapshot/snapuserd_daemon.h>
 #include <libsnapshot/snapuserd_server.h>
 
@@ -476,13 +477,13 @@
 bool Snapuserd::Init() {
     backing_store_fd_.reset(open(backing_store_device_.c_str(), O_RDONLY));
     if (backing_store_fd_ < 0) {
-        LOG(ERROR) << "Open Failed: " << backing_store_device_;
+        PLOG(ERROR) << "Open Failed: " << backing_store_device_;
         return false;
     }
 
     cow_fd_.reset(open(cow_device_.c_str(), O_RDWR));
     if (cow_fd_ < 0) {
-        LOG(ERROR) << "Open Failed: " << cow_device_;
+        PLOG(ERROR) << "Open Failed: " << cow_device_;
         return false;
     }
 
@@ -492,7 +493,7 @@
 
     ctrl_fd_.reset(open(control_path.c_str(), O_RDWR));
     if (ctrl_fd_ < 0) {
-        LOG(ERROR) << "Unable to open " << control_path;
+        PLOG(ERROR) << "Unable to open " << control_path;
         return false;
     }
 
@@ -623,7 +624,11 @@
 
     android::snapshot::Daemon& daemon = android::snapshot::Daemon::Instance();
 
-    daemon.StartServer(argv[1]);
+    std::string socket = android::snapshot::kSnapuserdSocket;
+    if (argc >= 2) {
+        socket = argv[1];
+    }
+    daemon.StartServer(socket);
     daemon.Run();
 
     return 0;
diff --git a/fs_mgr/libsnapshot/snapuserd_daemon.cpp b/fs_mgr/libsnapshot/snapuserd_daemon.cpp
index 8e76618..4c8fa57 100644
--- a/fs_mgr/libsnapshot/snapuserd_daemon.cpp
+++ b/fs_mgr/libsnapshot/snapuserd_daemon.cpp
@@ -20,16 +20,12 @@
 namespace android {
 namespace snapshot {
 
-int Daemon::StartServer(std::string socketname) {
-    int ret;
-
-    ret = server_.Start(socketname);
-    if (ret < 0) {
+bool Daemon::StartServer(const std::string& socketname) {
+    if (!server_.Start(socketname)) {
         LOG(ERROR) << "Snapuserd daemon failed to start...";
         exit(EXIT_FAILURE);
     }
-
-    return ret;
+    return true;
 }
 
 void Daemon::MaskAllSignalsExceptIntAndTerm() {
@@ -51,51 +47,26 @@
     }
 }
 
-Daemon::Daemon() {
-    is_running_ = true;
-}
-
-bool Daemon::IsRunning() {
-    return is_running_;
-}
-
 void Daemon::Run() {
-    poll_fd_ = std::make_unique<struct pollfd>();
-    poll_fd_->fd = server_.GetSocketFd().get();
-    poll_fd_->events = POLLIN;
-
     sigfillset(&signal_mask_);
     sigdelset(&signal_mask_, SIGINT);
     sigdelset(&signal_mask_, SIGTERM);
 
     // Masking signals here ensure that after this point, we won't handle INT/TERM
     // until after we call into ppoll()
-    MaskAllSignals();
     signal(SIGINT, Daemon::SignalHandler);
     signal(SIGTERM, Daemon::SignalHandler);
     signal(SIGPIPE, Daemon::SignalHandler);
 
     LOG(DEBUG) << "Snapuserd-server: ready to accept connections";
 
-    while (IsRunning()) {
-        int ret = ppoll(poll_fd_.get(), 1, nullptr, &signal_mask_);
-        MaskAllSignalsExceptIntAndTerm();
+    MaskAllSignalsExceptIntAndTerm();
 
-        if (ret == -1) {
-            PLOG(ERROR) << "Snapuserd:ppoll error";
-            break;
-        }
+    server_.Run();
+}
 
-        if (poll_fd_->revents == POLLIN) {
-            if (server_.AcceptClient() == static_cast<int>(DaemonOperations::STOP)) {
-                Daemon::Instance().is_running_ = false;
-            }
-        }
-
-        // Mask all signals to ensure that is_running_ can't become false between
-        // checking it in the while condition and calling into ppoll()
-        MaskAllSignals();
-    }
+void Daemon::Interrupt() {
+    server_.Interrupt();
 }
 
 void Daemon::SignalHandler(int signal) {
@@ -103,7 +74,7 @@
     switch (signal) {
         case SIGINT:
         case SIGTERM: {
-            Daemon::Instance().is_running_ = false;
+            Daemon::Instance().Interrupt();
             break;
         }
         case SIGPIPE: {
diff --git a/fs_mgr/libsnapshot/snapuserd_server.cpp b/fs_mgr/libsnapshot/snapuserd_server.cpp
index 53101aa..48a3b2a 100644
--- a/fs_mgr/libsnapshot/snapuserd_server.cpp
+++ b/fs_mgr/libsnapshot/snapuserd_server.cpp
@@ -35,12 +35,18 @@
 DaemonOperations SnapuserdServer::Resolveop(std::string& input) {
     if (input == "start") return DaemonOperations::START;
     if (input == "stop") return DaemonOperations::STOP;
-    if (input == "terminate-request") return DaemonOperations::TERMINATING;
     if (input == "query") return DaemonOperations::QUERY;
 
     return DaemonOperations::INVALID;
 }
 
+SnapuserdServer::~SnapuserdServer() {
+    // Close any client sockets that were added via AcceptClient().
+    for (size_t i = 1; i < watched_fds_.size(); i++) {
+        close(watched_fds_[i].fd);
+    }
+}
+
 std::string SnapuserdServer::GetDaemonStatus() {
     std::string msg = "";
 
@@ -67,7 +73,7 @@
                                   std::string control_device) {
     Snapuserd snapd(cow_device, backing_device, control_device);
     if (!snapd.Init()) {
-        PLOG(ERROR) << "Snapuserd: Init failed";
+        LOG(ERROR) << "Snapuserd: Init failed";
         return;
     }
 
@@ -84,158 +90,174 @@
 void SnapuserdServer::ShutdownThreads() {
     StopThreads();
 
-    for (auto& client : clients_vec_) {
+    for (auto& client : dm_users_) {
         auto& th = client->GetThreadHandler();
 
         if (th->joinable()) th->join();
     }
 }
 
-int SnapuserdServer::Sendmsg(int fd, char* msg, size_t size) {
-    int ret = TEMP_FAILURE_RETRY(send(fd, (char*)msg, size, 0));
+bool SnapuserdServer::Sendmsg(android::base::borrowed_fd fd, const std::string& msg) {
+    ssize_t ret = TEMP_FAILURE_RETRY(send(fd.get(), msg.data(), msg.size(), 0));
     if (ret < 0) {
         PLOG(ERROR) << "Snapuserd:server: send() failed";
-        return -1;
+        return false;
     }
 
-    if (ret < size) {
-        PLOG(ERROR) << "Partial data sent";
-        return -1;
+    if (ret < msg.size()) {
+        LOG(ERROR) << "Partial send; expected " << msg.size() << " bytes, sent " << ret;
+        return false;
     }
-
-    return 0;
+    return true;
 }
 
-std::string SnapuserdServer::Recvmsg(int fd, int* ret) {
-    struct timeval tv;
-    fd_set set;
+bool SnapuserdServer::Recv(android::base::borrowed_fd fd, std::string* data) {
     char msg[MAX_PACKET_SIZE];
+    ssize_t rv = TEMP_FAILURE_RETRY(recv(fd.get(), msg, sizeof(msg), 0));
+    if (rv < 0) {
+        PLOG(ERROR) << "recv failed";
+        return false;
+    }
+    *data = std::string(msg, rv);
+    return true;
+}
 
-    tv.tv_sec = 2;
-    tv.tv_usec = 0;
-    FD_ZERO(&set);
-    FD_SET(fd, &set);
-    *ret = select(fd + 1, &set, NULL, NULL, &tv);
-    if (*ret == -1) {  // select failed
-        return {};
-    } else if (*ret == 0) {  // timeout
-        return {};
+bool SnapuserdServer::Receivemsg(android::base::borrowed_fd fd, const std::string& str) {
+    const char delim = ',';
+
+    std::vector<std::string> out;
+    Parsemsg(str, delim, out);
+    DaemonOperations op = Resolveop(out[0]);
+
+    switch (op) {
+        case DaemonOperations::START: {
+            // Message format:
+            // start,<cow_device_path>,<source_device_path>,<control_device>
+            //
+            // Start the new thread which binds to dm-user misc device
+            auto handler = std::make_unique<DmUserHandler>();
+            handler->SetThreadHandler(
+                    std::bind(&SnapuserdServer::ThreadStart, this, out[1], out[2], out[3]));
+            dm_users_.push_back(std::move(handler));
+            return Sendmsg(fd, "success");
+        }
+        case DaemonOperations::STOP: {
+            // Message format: stop
+            //
+            // Stop all the threads gracefully and then shutdown the
+            // main thread
+            SetTerminating();
+            ShutdownThreads();
+            return true;
+        }
+        case DaemonOperations::QUERY: {
+            // Message format: query
+            //
+            // As part of transition, Second stage daemon will be
+            // created before terminating the first stage daemon. Hence,
+            // for a brief period client may have to distiguish between
+            // first stage daemon and second stage daemon.
+            //
+            // Second stage daemon is marked as active and hence will
+            // be ready to receive control message.
+            return Sendmsg(fd, GetDaemonStatus());
+        }
+        default: {
+            LOG(ERROR) << "Received unknown message type from client";
+            Sendmsg(fd, "fail");
+            return false;
+        }
+    }
+}
+
+bool SnapuserdServer::Start(const std::string& socketname) {
+    sockfd_.reset(android_get_control_socket(socketname.c_str()));
+    if (sockfd_ >= 0) {
+        if (listen(sockfd_.get(), 4) < 0) {
+            PLOG(ERROR) << "listen socket failed: " << socketname;
+            return false;
+        }
     } else {
-        *ret = TEMP_FAILURE_RETRY(recv(fd, msg, MAX_PACKET_SIZE, 0));
-        if (*ret < 0) {
-            PLOG(ERROR) << "Snapuserd:server: recv failed";
-            return {};
-        } else if (*ret == 0) {
-            LOG(DEBUG) << "Snapuserd client disconnected";
-            return {};
-        } else {
-            std::string str(msg);
-            return str;
+        sockfd_.reset(socket_local_server(socketname.c_str(), ANDROID_SOCKET_NAMESPACE_RESERVED,
+                                          SOCK_STREAM));
+        if (sockfd_ < 0) {
+            PLOG(ERROR) << "Failed to create server socket " << socketname;
+            return false;
         }
     }
-}
 
-int SnapuserdServer::Receivemsg(int fd) {
-    char msg[MAX_PACKET_SIZE];
-    std::unique_ptr<Client> newClient;
-    int ret = 0;
-
-    while (1) {
-        memset(msg, '\0', MAX_PACKET_SIZE);
-        std::string str = Recvmsg(fd, &ret);
-
-        if (ret <= 0) {
-            LOG(DEBUG) << "recv failed with ret: " << ret;
-            return 0;
-        }
-
-        const char delim = ',';
-
-        std::vector<std::string> out;
-        Parsemsg(str, delim, out);
-        DaemonOperations op = Resolveop(out[0]);
-        memset(msg, '\0', MAX_PACKET_SIZE);
-
-        switch (op) {
-            case DaemonOperations::START: {
-                // Message format:
-                // start,<cow_device_path>,<source_device_path>,<control_device>
-                //
-                // Start the new thread which binds to dm-user misc device
-                newClient = std::make_unique<Client>();
-                newClient->SetThreadHandler(
-                        std::bind(&SnapuserdServer::ThreadStart, this, out[1], out[2], out[3]));
-                clients_vec_.push_back(std::move(newClient));
-                sprintf(msg, "success");
-                Sendmsg(fd, msg, MAX_PACKET_SIZE);
-                return 0;
-            }
-            case DaemonOperations::STOP: {
-                // Message format: stop
-                //
-                // Stop all the threads gracefully and then shutdown the
-                // main thread
-                ShutdownThreads();
-                return static_cast<int>(DaemonOperations::STOP);
-            }
-            case DaemonOperations::TERMINATING: {
-                // Message format: terminate-request
-                //
-                // This is invoked during transition. First stage
-                // daemon will receive this request. First stage daemon
-                // will be considered as a passive daemon from hereon.
-                SetTerminating();
-                sprintf(msg, "success");
-                Sendmsg(fd, msg, MAX_PACKET_SIZE);
-                return 0;
-            }
-            case DaemonOperations::QUERY: {
-                // Message format: query
-                //
-                // As part of transition, Second stage daemon will be
-                // created before terminating the first stage daemon. Hence,
-                // for a brief period client may have to distiguish between
-                // first stage daemon and second stage daemon.
-                //
-                // Second stage daemon is marked as active and hence will
-                // be ready to receive control message.
-                std::string dstr = GetDaemonStatus();
-                memcpy(msg, dstr.c_str(), dstr.size());
-                Sendmsg(fd, msg, MAX_PACKET_SIZE);
-                if (dstr == "active")
-                    break;
-                else
-                    return 0;
-            }
-            default: {
-                sprintf(msg, "fail");
-                Sendmsg(fd, msg, MAX_PACKET_SIZE);
-                return 0;
-            }
-        }
-    }
-}
-
-int SnapuserdServer::Start(std::string socketname) {
-    sockfd_.reset(socket_local_server(socketname.c_str(), ANDROID_SOCKET_NAMESPACE_RESERVED,
-                                      SOCK_STREAM));
-    if (sockfd_ < 0) {
-        PLOG(ERROR) << "Failed to create server socket " << socketname;
-        return -1;
-    }
+    AddWatchedFd(sockfd_);
 
     LOG(DEBUG) << "Snapuserd server successfully started with socket name " << socketname;
-    return 0;
+    return true;
 }
 
-int SnapuserdServer::AcceptClient() {
-    int fd = accept(sockfd_.get(), NULL, NULL);
+bool SnapuserdServer::Run() {
+    while (!IsTerminating()) {
+        int rv = TEMP_FAILURE_RETRY(poll(watched_fds_.data(), watched_fds_.size(), -1));
+        if (rv < 0) {
+            PLOG(ERROR) << "poll failed";
+            return false;
+        }
+        if (!rv) {
+            continue;
+        }
+
+        if (watched_fds_[0].revents) {
+            AcceptClient();
+        }
+
+        auto iter = watched_fds_.begin() + 1;
+        while (iter != watched_fds_.end()) {
+            if (iter->revents && !HandleClient(iter->fd, iter->revents)) {
+                close(iter->fd);
+                iter = watched_fds_.erase(iter);
+            } else {
+                iter++;
+            }
+        }
+    }
+    return true;
+}
+
+void SnapuserdServer::AddWatchedFd(android::base::borrowed_fd fd) {
+    struct pollfd p = {};
+    p.fd = fd.get();
+    p.events = POLLIN;
+    watched_fds_.emplace_back(std::move(p));
+}
+
+void SnapuserdServer::AcceptClient() {
+    int fd = TEMP_FAILURE_RETRY(accept4(sockfd_.get(), nullptr, nullptr, SOCK_CLOEXEC));
     if (fd < 0) {
-        PLOG(ERROR) << "Socket accept failed: " << strerror(errno);
-        return -1;
+        PLOG(ERROR) << "accept4 failed";
+        return;
     }
 
-    return Receivemsg(fd);
+    AddWatchedFd(fd);
+}
+
+bool SnapuserdServer::HandleClient(android::base::borrowed_fd fd, int revents) {
+    if (revents & POLLHUP) {
+        LOG(DEBUG) << "Snapuserd client disconnected";
+        return false;
+    }
+
+    std::string str;
+    if (!Recv(fd, &str)) {
+        return false;
+    }
+    if (!Receivemsg(fd, str)) {
+        LOG(ERROR) << "Encountered error handling client message, revents: " << revents;
+        return false;
+    }
+    return true;
+}
+
+void SnapuserdServer::Interrupt() {
+    // Force close the socket so poll() fails.
+    sockfd_ = {};
+    SetTerminating();
 }
 
 }  // namespace snapshot