snapuserd: Refactor client to allow persistent connections.

This refactors SnapuserdClient so it retains a connection for its
lifetime. This allows SnapshotManager to ensure the daemon is running
and hold a connection open across all of its operations.

The main impetus of this change is to remove the ambiguity between the
first-stage and second-stage sockets. SnapshotManager should only ever
connect to the first-stage socket during first-stage init, or to
initiate the "transition" step during second-stage init.

The transition steps are roughly:
 (1) Start second-stage daemon.
 (2) Load new device-mapper tables.
 (3) Connect second-stage daemon to new dm-user devices.
 (4) Activate the new tables, flushing IO to the first-stage daemon.
 (5) Send a signal to the first-stage daemon to exit.

This patch makes it easier to hold these two separate connections, one
to the first-stage daemon and one to the second-stage daemon.

Bug: 168554689
Test: manual test
Change-Id: I51cb9adecffb19143ed685e0c33456177ec3d81f
diff --git a/fs_mgr/libsnapshot/cow_snapuserd_test.cpp b/fs_mgr/libsnapshot/cow_snapuserd_test.cpp
index 1d6c104..4eab9a5 100644
--- a/fs_mgr/libsnapshot/cow_snapuserd_test.cpp
+++ b/fs_mgr/libsnapshot/cow_snapuserd_test.cpp
@@ -65,7 +65,7 @@
         product_a_ = std::make_unique<TemporaryFile>(path);
         ASSERT_GE(product_a_->fd, 0) << strerror(errno);
 
-        size_ = 100_MiB;
+        size_ = 1_MiB;
     }
 
     void TearDown() override {
@@ -123,7 +123,7 @@
     }
 
     void TestIO(unique_fd& snapshot_fd, std::unique_ptr<uint8_t[]>& buffer);
-    SnapuserdClient client_;
+    std::unique_ptr<SnapuserdClient> client_;
 };
 
 void SnapuserdTest::Init() {
@@ -151,12 +151,12 @@
         offset += 1_MiB;
     }
 
-    for (size_t j = 0; j < (800_MiB / 1_MiB); j++) {
+    for (size_t j = 0; j < (8_MiB / 1_MiB); j++) {
         ASSERT_EQ(ReadFullyAtOffset(rnd_fd, (char*)random_buffer.get(), 1_MiB, 0), true);
         ASSERT_EQ(android::base::WriteFully(system_a_->fd, random_buffer.get(), 1_MiB), true);
     }
 
-    for (size_t j = 0; j < (800_MiB / 1_MiB); j++) {
+    for (size_t j = 0; j < (8_MiB / 1_MiB); j++) {
         ASSERT_EQ(ReadFullyAtOffset(rnd_fd, (char*)random_buffer.get(), 1_MiB, 0), true);
         ASSERT_EQ(android::base::WriteFully(product_a_->fd, random_buffer.get(), 1_MiB), true);
     }
@@ -297,18 +297,18 @@
 }
 
 void SnapuserdTest::StartSnapuserdDaemon() {
-    int ret;
+    ASSERT_TRUE(EnsureSnapuserdStarted());
 
-    ret = client_.StartSnapuserd();
-    ASSERT_EQ(ret, 0);
+    client_ = SnapuserdClient::Connect(kSnapuserdSocket, 5s);
+    ASSERT_NE(client_, nullptr);
 
-    ret = client_.InitializeSnapuserd(cow_system_->path, system_a_loop_->device(),
-                                      GetSystemControlPath());
-    ASSERT_EQ(ret, 0);
+    bool ok = client_->InitializeSnapuserd(cow_system_->path, system_a_loop_->device(),
+                                           GetSystemControlPath());
+    ASSERT_TRUE(ok);
 
-    ret = client_.InitializeSnapuserd(cow_product_->path, product_a_loop_->device(),
+    ok = client_->InitializeSnapuserd(cow_product_->path, product_a_loop_->device(),
                                       GetProductControlPath());
-    ASSERT_EQ(ret, 0);
+    ASSERT_TRUE(ok);
 }
 
 void SnapuserdTest::CreateSnapshotDevices() {
@@ -464,10 +464,6 @@
             {cow_system_1_->path, system_a_loop_->device(), GetSystemControlPath()},
             {cow_product_1_->path, product_a_loop_->device(), GetProductControlPath()}};
 
-    // Start the second stage deamon and send the devices information through
-    // vector.
-    ASSERT_EQ(client_.RestartSnapuserd(vec), 0);
-
     // TODO: This is not switching snapshot device but creates a new table;
     // Second stage daemon will be ready to serve the IO request. From now
     // onwards, we can go ahead and shutdown the first stage daemon
@@ -476,9 +472,6 @@
     DeleteDmUser(cow_system_, "system-snapshot");
     DeleteDmUser(cow_product_, "product-snapshot");
 
-    // Stop the first stage daemon
-    ASSERT_EQ(client_.StopSnapuserd(true), 0);
-
     // Test the IO again with the second stage daemon
     snapshot_fd.reset(open("/dev/block/mapper/system-snapshot-1", O_RDONLY));
     ASSERT_TRUE(snapshot_fd > 0);
@@ -494,7 +487,7 @@
     DeleteDmUser(cow_product_1_, "product-snapshot-1");
 
     // Stop the second stage daemon
-    ASSERT_EQ(client_.StopSnapuserd(false), 0);
+    ASSERT_TRUE(client_->StopSnapuserd());
 }
 
 }  // namespace snapshot
diff --git a/fs_mgr/libsnapshot/include/libsnapshot/snapuserd_client.h b/fs_mgr/libsnapshot/include/libsnapshot/snapuserd_client.h
index d6713b8..dffd481 100644
--- a/fs_mgr/libsnapshot/include/libsnapshot/snapuserd_client.h
+++ b/fs_mgr/libsnapshot/include/libsnapshot/snapuserd_client.h
@@ -14,49 +14,45 @@
 
 #pragma once
 
+#include <chrono>
 #include <cstring>
 #include <iostream>
 #include <string>
 #include <thread>
 #include <vector>
 
+#include <android-base/unique_fd.h>
+
 namespace android {
 namespace snapshot {
 
 static constexpr uint32_t PACKET_SIZE = 512;
-static constexpr uint32_t MAX_CONNECT_RETRY_COUNT = 10;
 
 static constexpr char kSnapuserdSocketFirstStage[] = "snapuserd_first_stage";
 static constexpr char kSnapuserdSocket[] = "snapuserd";
 
+// Ensure that the second-stage daemon for snapuserd is running.
+bool EnsureSnapuserdStarted();
+
 class SnapuserdClient {
   private:
-    int sockfd_ = 0;
+    android::base::unique_fd sockfd_;
 
-    int Sendmsg(const char* msg, size_t size);
+    bool Sendmsg(const std::string& msg);
     std::string Receivemsg();
-    int StartSnapuserdaemon(std::string socketname);
-    bool ConnectToServerSocket(std::string socketname);
-    bool ConnectToServer();
 
-    void DisconnectFromServer() { close(sockfd_); }
-
-    std::string GetSocketNameFirstStage() {
-        static std::string snapd_one("snapdone");
-        return snapd_one;
-    }
-
-    std::string GetSocketNameSecondStage() {
-        static std::string snapd_two("snapdtwo");
-        return snapd_two;
-    }
+    bool ValidateConnection();
 
   public:
-    int StartSnapuserd();
-    int StopSnapuserd(bool firstStageDaemon);
+    explicit SnapuserdClient(android::base::unique_fd&& sockfd);
+
+    static std::unique_ptr<SnapuserdClient> Connect(const std::string& socket_name,
+                                                    std::chrono::milliseconds timeout_ms);
+
+    bool StopSnapuserd();
     int RestartSnapuserd(std::vector<std::vector<std::string>>& vec);
-    int InitializeSnapuserd(std::string cow_device, std::string backing_device,
-                            std::string control_device);
+    bool InitializeSnapuserd(const std::string& cow_device, const std::string& backing_device,
+                             const std::string& control_device);
 };
 
 }  // namespace snapshot
diff --git a/fs_mgr/libsnapshot/snapuserd_client.cpp b/fs_mgr/libsnapshot/snapuserd_client.cpp
index 78dbada..532e585 100644
--- a/fs_mgr/libsnapshot/snapuserd_client.cpp
+++ b/fs_mgr/libsnapshot/snapuserd_client.cpp
@@ -29,72 +29,98 @@
 #include <chrono>
 
 #include <android-base/logging.h>
+#include <android-base/properties.h>
 #include <libsnapshot/snapuserd_client.h>
 
 namespace android {
 namespace snapshot {
 
-bool SnapuserdClient::ConnectToServerSocket(std::string socketname) {
-    sockfd_ = 0;
+using namespace std::chrono_literals;
+using android::base::unique_fd;
 
-    sockfd_ =
-            socket_local_client(socketname.c_str(), ANDROID_SOCKET_NAMESPACE_RESERVED, SOCK_STREAM);
-    if (sockfd_ < 0) {
-        LOG(ERROR) << "Failed to connect to " << socketname;
-        return false;
+bool EnsureSnapuserdStarted() {
+    if (android::base::GetProperty("init.svc.snapuserd", "") == "running") {
+        return true;
     }
 
-    std::string msg = "query";
+    android::base::SetProperty("ctl.start", "snapuserd");
+    if (!android::base::WaitForProperty("init.svc.snapuserd", "running", 10s)) {
+        LOG(ERROR) << "Timed out waiting for snapuserd to start.";
+        return false;
+    }
+    return true;
+}
 
-    int sendRet = Sendmsg(msg.c_str(), msg.size());
-    if (sendRet < 0) {
-        LOG(ERROR) << "Failed to send query message to snapuserd daemon with socket " << socketname;
-        DisconnectFromServer();
+SnapuserdClient::SnapuserdClient(android::base::unique_fd&& sockfd) : sockfd_(std::move(sockfd)) {}
+
+static inline bool IsRetryErrno() {
+    return errno == ECONNREFUSED || errno == EINTR;
+}
+
+std::unique_ptr<SnapuserdClient> SnapuserdClient::Connect(const std::string& socket_name,
+                                                          std::chrono::milliseconds timeout_ms) {
+    unique_fd fd;
+    auto start = std::chrono::steady_clock::now();
+    while (true) {
+        fd.reset(socket_local_client(socket_name.c_str(), ANDROID_SOCKET_NAMESPACE_RESERVED,
+                                     SOCK_STREAM));
+        if (fd >= 0) break;
+        if (fd < 0 && !IsRetryErrno()) {
+            PLOG(ERROR) << "connect failed: " << socket_name;
+            return nullptr;
+        }
+
+        auto now = std::chrono::steady_clock::now();
+        auto elapsed = std::chrono::duration_cast<std::chrono::milliseconds>(now - start);
+        if (elapsed >= timeout_ms) {
+            LOG(ERROR) << "Timed out connecting to snapuserd socket: " << socket_name;
+            return nullptr;
+        }
+
+        std::this_thread::sleep_for(100ms);
+    }
+
+    auto client = std::make_unique<SnapuserdClient>(std::move(fd));
+    if (!client->ValidateConnection()) {
+        return nullptr;
+    }
+    return client;
+}
+
+bool SnapuserdClient::ValidateConnection() {
+    if (!Sendmsg("query")) {
         return false;
     }
 
     std::string str = Receivemsg();
 
-    if (str.find("fail") != std::string::npos) {
-        LOG(ERROR) << "Failed to receive message from snapuserd daemon with socket " << socketname;
-        DisconnectFromServer();
-        return false;
-    }
-
     // If the daemon is passive then fallback to secondary active daemon. Daemon
     // is passive during transition phase. Please see RestartSnapuserd()
     if (str.find("passive") != std::string::npos) {
-        LOG(DEBUG) << "Snapuserd is passive with socket " << socketname;
-        DisconnectFromServer();
+        LOG(ERROR) << "Snapuserd is terminating";
         return false;
     }
 
-    CHECK(str.find("active") != std::string::npos);
-
+    if (str != "active") {
+        LOG(ERROR) << "Received failure querying daemon";
+        return false;
+    }
     return true;
 }
 
-bool SnapuserdClient::ConnectToServer() {
-    if (ConnectToServerSocket(GetSocketNameFirstStage())) return true;
-
-    if (ConnectToServerSocket(GetSocketNameSecondStage())) return true;
-
-    return false;
-}
-
-int SnapuserdClient::Sendmsg(const char* msg, size_t size) {
-    int numBytesSent = TEMP_FAILURE_RETRY(send(sockfd_, msg, size, 0));
+bool SnapuserdClient::Sendmsg(const std::string& msg) {
+    ssize_t numBytesSent = TEMP_FAILURE_RETRY(send(sockfd_, msg.data(), msg.size(), 0));
     if (numBytesSent < 0) {
-        LOG(ERROR) << "Send failed " << strerror(errno);
-        return -1;
+        PLOG(ERROR) << "Send failed";
+        return false;
     }
 
-    if ((uint)numBytesSent < size) {
-        LOG(ERROR) << "Partial data sent " << strerror(errno);
-        return -1;
+    if ((size_t)numBytesSent < msg.size()) {
+        LOG(ERROR) << "Partial data sent, expected " << msg.size() << " bytes, sent "
+                   << numBytesSent;
+        return false;
     }
-
-    return 0;
+    return true;
 }
 
 std::string SnapuserdClient::Receivemsg() {
@@ -127,98 +153,33 @@
     return msgStr;
 }
 
-int SnapuserdClient::StopSnapuserd(bool firstStageDaemon) {
-    if (firstStageDaemon) {
-        sockfd_ = socket_local_client(GetSocketNameFirstStage().c_str(),
-                                      ANDROID_SOCKET_NAMESPACE_RESERVED, SOCK_STREAM);
-        if (sockfd_ < 0) {
-            LOG(ERROR) << "Failed to connect to " << GetSocketNameFirstStage();
-            return -1;
-        }
-    } else {
-        if (!ConnectToServer()) {
-            LOG(ERROR) << "Failed to connect to socket " << GetSocketNameSecondStage();
-            return -1;
-        }
-    }
-
-    std::string msg = "stop";
-
-    int sendRet = Sendmsg(msg.c_str(), msg.size());
-    if (sendRet < 0) {
+bool SnapuserdClient::StopSnapuserd() {
+    if (!Sendmsg("stop")) {
         LOG(ERROR) << "Failed to send stop message to snapuserd daemon";
-        return -1;
+        return false;
     }
 
-    DisconnectFromServer();
-
-    return 0;
+    sockfd_ = {};
+    return true;
 }
 
-int SnapuserdClient::StartSnapuserdaemon(std::string socketname) {
-    int retry_count = 0;
-
-    if (fork() == 0) {
-        const char* argv[] = {"/system/bin/snapuserd", socketname.c_str(), nullptr};
-        if (execv(argv[0], const_cast<char**>(argv))) {
-            LOG(ERROR) << "Failed to exec snapuserd daemon";
-            return -1;
-        }
-    }
-
-    // snapuserd is a daemon and will never exit; parent can't wait here
-    // to get the return code. Since Snapuserd starts the socket server,
-    // give it some time to fully launch.
-    //
-    // Try to connect to server to verify snapuserd server is started
-    while (retry_count < MAX_CONNECT_RETRY_COUNT) {
-        if (!ConnectToServer()) {
-            retry_count++;
-            std::this_thread::sleep_for(std::chrono::milliseconds(500));
-        } else {
-            close(sockfd_);
-            return 0;
-        }
-    }
-
-    LOG(ERROR) << "Failed to start snapuserd daemon";
-    return -1;
-}
-
-int SnapuserdClient::StartSnapuserd() {
-    if (StartSnapuserdaemon(GetSocketNameFirstStage()) < 0) return -1;
-
-    return 0;
-}
-
-int SnapuserdClient::InitializeSnapuserd(std::string cow_device, std::string backing_device,
-                                         std::string control_device) {
-    int ret = 0;
-
-    if (!ConnectToServer()) {
-        LOG(ERROR) << "Failed to connect to server ";
-        return -1;
-    }
-
+bool SnapuserdClient::InitializeSnapuserd(const std::string& cow_device,
+                                          const std::string& backing_device,
+                                          const std::string& control_device) {
     std::string msg = "start," + cow_device + "," + backing_device + "," + control_device;
-
-    ret = Sendmsg(msg.c_str(), msg.size());
-    if (ret < 0) {
+    if (!Sendmsg(msg)) {
         LOG(ERROR) << "Failed to send message " << msg << " to snapuserd daemon";
-        return -1;
+        return false;
     }
 
     std::string str = Receivemsg();
-
-    if (str.find("fail") != std::string::npos) {
+    if (str != "success") {
         LOG(ERROR) << "Failed to receive ack for " << msg << " from snapuserd daemon";
-        return -1;
+        return false;
     }
 
-    DisconnectFromServer();
-
     LOG(DEBUG) << "Snapuserd daemon initialized with " << msg;
-    return 0;
+    return true;
 }
 
 /*
@@ -254,18 +215,8 @@
  *
  */
 int SnapuserdClient::RestartSnapuserd(std::vector<std::vector<std::string>>& vec) {
-    // Connect to first-stage daemon and send a terminate-request control
-    // message. This will not terminate the daemon but will mark the daemon as
-    // passive.
-    if (!ConnectToServer()) {
-        LOG(ERROR) << "Failed to connect to server ";
-        return -1;
-    }
-
     std::string msg = "terminate-request";
-
-    int sendRet = Sendmsg(msg.c_str(), msg.size());
-    if (sendRet < 0) {
+    if (!Sendmsg(msg)) {
         LOG(ERROR) << "Failed to send message " << msg << " to snapuserd daemon";
         return -1;
     }
@@ -279,16 +230,13 @@
 
     CHECK(str.find("success") != std::string::npos);
 
-    DisconnectFromServer();
-
     // Start the new daemon
-    if (StartSnapuserdaemon(GetSocketNameSecondStage()) < 0) {
-        LOG(ERROR) << "Failed to start new daemon at socket " << GetSocketNameSecondStage();
+    if (!EnsureSnapuserdStarted()) {
+        LOG(ERROR) << "Failed to start new daemon";
         return -1;
     }
 
-    LOG(DEBUG) << "Second stage Snapuserd daemon created successfully at socket "
-               << GetSocketNameSecondStage();
+    LOG(DEBUG) << "Second stage Snapuserd daemon created successfully";
 
     // Vector contains all the device information to be passed to the new
     // daemon. Note that the caller can choose to initialize separately