Merge "libsnapshot:snapuserd: Handle signals"
diff --git a/fs_mgr/libsnapshot/include/libsnapshot/snapuserd.h b/fs_mgr/libsnapshot/include/libsnapshot/snapuserd.h
index 6331edb..d495014 100644
--- a/fs_mgr/libsnapshot/include/libsnapshot/snapuserd.h
+++ b/fs_mgr/libsnapshot/include/libsnapshot/snapuserd.h
@@ -65,7 +65,7 @@
           backing_store_device_(in_backing_store_device),
           metadata_read_done_(false) {}
 
-    int Init();
+    bool Init();
     int Run();
     int ReadDmUserHeader();
     int WriteDmUserPayload(size_t size);
diff --git a/fs_mgr/libsnapshot/include/libsnapshot/snapuserd_client.h b/fs_mgr/libsnapshot/include/libsnapshot/snapuserd_client.h
index 2d9d729..535e923 100644
--- a/fs_mgr/libsnapshot/include/libsnapshot/snapuserd_client.h
+++ b/fs_mgr/libsnapshot/include/libsnapshot/snapuserd_client.h
@@ -14,22 +14,8 @@
 
 #pragma once
 
-#include <arpa/inet.h>
-#include <cutils/sockets.h>
-#include <errno.h>
-#include <netdb.h>
-#include <netinet/in.h>
-#include <stdint.h>
-#include <stdio.h>
-#include <stdlib.h>
-#include <sys/socket.h>
-#include <sys/types.h>
-#include <unistd.h>
-
-#include <chrono>
 #include <cstring>
 #include <iostream>
-#include <sstream>
 #include <string>
 #include <thread>
 #include <vector>
diff --git a/fs_mgr/libsnapshot/include/libsnapshot/snapuserd_daemon.h b/fs_mgr/libsnapshot/include/libsnapshot/snapuserd_daemon.h
index c0d3c5e..94542d7 100644
--- a/fs_mgr/libsnapshot/include/libsnapshot/snapuserd_daemon.h
+++ b/fs_mgr/libsnapshot/include/libsnapshot/snapuserd_daemon.h
@@ -14,6 +14,8 @@
 
 #pragma once
 
+#include <poll.h>
+
 #include <libsnapshot/snapuserd_server.h>
 
 namespace android {
@@ -34,12 +36,17 @@
 
   private:
     bool is_running_;
+    std::unique_ptr<struct pollfd> poll_fd_;
+    // Signal mask used with ppoll()
+    sigset_t signal_mask_;
 
     Daemon();
     Daemon(Daemon const&) = delete;
     void operator=(Daemon const&) = delete;
 
     SnapuserdServer server_;
+    void MaskAllSignalsExceptIntAndTerm();
+    void MaskAllSignals();
     static void SignalHandler(int signal);
 };
 
diff --git a/fs_mgr/libsnapshot/include/libsnapshot/snapuserd_server.h b/fs_mgr/libsnapshot/include/libsnapshot/snapuserd_server.h
index 79b883a..584fe71 100644
--- a/fs_mgr/libsnapshot/include/libsnapshot/snapuserd_server.h
+++ b/fs_mgr/libsnapshot/include/libsnapshot/snapuserd_server.h
@@ -14,18 +14,6 @@
 
 #pragma once
 
-#include <stdint.h>
-
-#include <arpa/inet.h>
-#include <cutils/sockets.h>
-#include <netinet/in.h>
-#include <stdio.h>
-#include <stdlib.h>
-#include <sys/socket.h>
-#include <sys/types.h>
-#include <unistd.h>
-
-#include <errno.h>
 #include <cstdio>
 #include <cstring>
 #include <functional>
@@ -89,6 +77,7 @@
     android::base::unique_fd sockfd_;
     bool terminating_;
     std::vector<std::unique_ptr<Client>> clients_vec_;
+
     void ThreadStart(std::string cow_device, std::string backing_device) override;
     void ShutdownThreads();
     DaemonOperations Resolveop(std::string& input);
@@ -100,8 +89,6 @@
     bool IsTerminating() { return terminating_; }
 
   public:
-    ~SnapuserdServer() { clients_vec_.clear(); }
-
     SnapuserdServer() { terminating_ = false; }
 
     int Start(std::string socketname);
@@ -109,6 +96,7 @@
     int Receivemsg(int fd);
     int Sendmsg(int fd, char* msg, size_t len);
     std::string Recvmsg(int fd, int* ret);
+    android::base::borrowed_fd GetSocketFd() { return sockfd_; }
 };
 
 }  // namespace snapshot
diff --git a/fs_mgr/libsnapshot/snapuserd.cpp b/fs_mgr/libsnapshot/snapuserd.cpp
index 34481b7..3ed853f 100644
--- a/fs_mgr/libsnapshot/snapuserd.cpp
+++ b/fs_mgr/libsnapshot/snapuserd.cpp
@@ -485,17 +485,17 @@
     return sizeof(struct dm_user_header) + size;
 }
 
-int Snapuserd::Init() {
+bool Snapuserd::Init() {
     backing_store_fd_.reset(open(backing_store_device_.c_str(), O_RDONLY));
     if (backing_store_fd_ < 0) {
         LOG(ERROR) << "Open Failed: " << backing_store_device_;
-        return 1;
+        return false;
     }
 
     cow_fd_.reset(open(cow_device_.c_str(), O_RDWR));
     if (cow_fd_ < 0) {
         LOG(ERROR) << "Open Failed: " << cow_device_;
-        return 1;
+        return false;
     }
 
     std::string str(cow_device_);
@@ -509,7 +509,7 @@
     std::string uuid;
     if (!dm.GetDmDeviceUuidByName(device_name, &uuid)) {
         LOG(ERROR) << "Unable to find UUID for " << cow_device_;
-        return 1;
+        return false;
     }
 
     LOG(DEBUG) << "UUID: " << uuid;
@@ -518,7 +518,7 @@
     ctrl_fd_.reset(open(t.control_path().c_str(), O_RDWR));
     if (ctrl_fd_ < 0) {
         LOG(ERROR) << "Unable to open " << t.control_path();
-        return 1;
+        return false;
     }
 
     // Allocate the buffer which is used to communicate between
@@ -528,7 +528,7 @@
     size_t buf_size = sizeof(struct dm_user_header) + PAYLOAD_SIZE;
     bufsink_.Initialize(buf_size);
 
-    return 0;
+    return true;
 }
 
 int Snapuserd::Run() {
@@ -601,6 +601,11 @@
                         ret = ReadData(chunk + num_chunks_read, read_size);
                         if (ret < 0) {
                             LOG(ERROR) << "ReadData failed";
+                            // TODO: Bug 168259959: All the error paths from this function
+                            // should send error code to dm-user thereby IO
+                            // terminates with an error from dm-user. Returning
+                            // here without sending error code will block the
+                            // IO.
                             return ret;
                         }
                     }
@@ -622,7 +627,7 @@
         }
 
         case DM_USER_MAP_WRITE: {
-            // TODO: After merge operation is completed, kernel issues write
+            // TODO: Bug: 168311203: After merge operation is completed, kernel issues write
             // to flush all the exception mappings where the merge is
             // completed. If dm-user routes the WRITE IO, we need to clear
             // in-memory data structures representing those exception
diff --git a/fs_mgr/libsnapshot/snapuserd_client.cpp b/fs_mgr/libsnapshot/snapuserd_client.cpp
index b10de35..bef8f5c 100644
--- a/fs_mgr/libsnapshot/snapuserd_client.cpp
+++ b/fs_mgr/libsnapshot/snapuserd_client.cpp
@@ -1,3 +1,33 @@
+/*
+ * Copyright (C) 2020 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <arpa/inet.h>
+#include <cutils/sockets.h>
+#include <errno.h>
+#include <netdb.h>
+#include <netinet/in.h>
+#include <stdint.h>
+#include <stdio.h>
+#include <stdlib.h>
+#include <sys/socket.h>
+#include <sys/types.h>
+#include <unistd.h>
+
+#include <chrono>
+
 #include <android-base/logging.h>
 #include <libsnapshot/snapuserd_client.h>
 
@@ -68,6 +98,43 @@
 }
 
+// Waits up to 2 seconds for a reply from the daemon on sockfd_.
+//
+// Returns the reply text, or "fail" if select() fails or times out, recv()
+// fails, or the daemon has closed the connection.
 std::string SnapuserdClient::Receivemsg() {
+    int ret;
+    struct timeval tv;
+    fd_set set;
+    char msg[PACKET_SIZE];
+    std::string msgStr("fail");
+
+    tv.tv_sec = 2;
+    tv.tv_usec = 0;
+    FD_ZERO(&set);
+    FD_SET(sockfd_, &set);
+    ret = select(sockfd_ + 1, &set, nullptr, nullptr, &tv);
+    if (ret == -1) {  // select failed
+        PLOG(ERROR) << "Snapuserd:client: Select call failed";
+    } else if (ret == 0) {  // timeout
+        LOG(ERROR) << "Snapuserd:client: Select call timeout";
+    } else {
+        ret = TEMP_FAILURE_RETRY(recv(sockfd_, msg, PACKET_SIZE, 0));
+        if (ret < 0) {
+            PLOG(ERROR) << "Snapuserd:client: recv failed";
+        } else if (ret == 0) {
+            LOG(DEBUG) << "Snapuserd:client disconnected";
+        } else {
+            // recv() does not NUL-terminate the buffer, and a reply that fills
+            // all PACKET_SIZE bytes leaves no room for a terminator, so bound
+            // the copy by the received byte count.
+            msgStr.assign(msg, strnlen(msg, ret));
+        }
+    }
+    return msgStr;
+}
+
+#if 0
+std::string SnapuserdClient::Receivemsg() {
     char msg[PACKET_SIZE];
     std::string msgStr("fail");
     int ret;
@@ -82,6 +143,7 @@
     msgStr = msg;
     return msgStr;
 }
+#endif
 
 int SnapuserdClient::StopSnapuserd(bool firstStageDaemon) {
     if (firstStageDaemon) {
diff --git a/fs_mgr/libsnapshot/snapuserd_daemon.cpp b/fs_mgr/libsnapshot/snapuserd_daemon.cpp
index c1008b9..8e76618 100644
--- a/fs_mgr/libsnapshot/snapuserd_daemon.cpp
+++ b/fs_mgr/libsnapshot/snapuserd_daemon.cpp
@@ -1,3 +1,19 @@
+/*
+ * Copyright (C) 2020 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
 #include <android-base/logging.h>
 #include <libsnapshot/snapuserd_daemon.h>
 
@@ -16,11 +32,35 @@
     return ret;
 }
 
+// Blocks every signal except SIGINT, SIGTERM and SIGPIPE. Run() calls this
+// right after ppoll() returns, so shutdown requests are still delivered while
+// a client is being handled, and a SIGPIPE raised by writing to a client that
+// has gone away reaches SignalHandler, which only logs it.
+void Daemon::MaskAllSignalsExceptIntAndTerm() {
+    sigset_t signal_mask;
+    sigfillset(&signal_mask);
+    sigdelset(&signal_mask, SIGINT);
+    sigdelset(&signal_mask, SIGTERM);
+    sigdelset(&signal_mask, SIGPIPE);
+    if (sigprocmask(SIG_SETMASK, &signal_mask, nullptr) != 0) {
+        PLOG(ERROR) << "Failed to set sigprocmask";
+    }
+}
+
+// Blocks every blockable signal (sigprocmask() silently ignores SIGKILL and
+// SIGSTOP). Run() uses this so no handler can clear is_running_ between the
+// loop-condition check and the next ppoll().
+void Daemon::MaskAllSignals() {
+    sigset_t signal_mask;
+    sigfillset(&signal_mask);
+    if (sigprocmask(SIG_SETMASK, &signal_mask, nullptr) != 0) {
+        PLOG(ERROR) << "Couldn't mask all signals";
+    }
+}
+
 Daemon::Daemon() {
     is_running_ = true;
-    // TODO: Mask other signals - Bug 168258493
-    signal(SIGINT, Daemon::SignalHandler);
-    signal(SIGTERM, Daemon::SignalHandler);
+    // Signal handlers are installed by Run(), once all signals are masked.
 }
 
 bool Daemon::IsRunning() {
@@ -28,10 +60,58 @@
 }
 
 void Daemon::Run() {
+    poll_fd_ = std::make_unique<struct pollfd>();
+    poll_fd_->fd = server_.GetSocketFd().get();
+    poll_fd_->events = POLLIN;
+
+    // Mask handed to ppoll(): while blocked waiting for a connection, only
+    // SIGINT and SIGTERM can be delivered.
+    sigfillset(&signal_mask_);
+    sigdelset(&signal_mask_, SIGINT);
+    sigdelset(&signal_mask_, SIGTERM);
+
+    // Masking signals here ensure that after this point, we won't handle INT/TERM
+    // until after we call into ppoll()
+    MaskAllSignals();
+    signal(SIGINT, Daemon::SignalHandler);
+    signal(SIGTERM, Daemon::SignalHandler);
+    signal(SIGPIPE, Daemon::SignalHandler);
+
+    LOG(DEBUG) << "Snapuserd-server: ready to accept connections";
+
     while (IsRunning()) {
-        if (server_.AcceptClient() == static_cast<int>(DaemonOperations::STOP)) {
-            Daemon::Instance().is_running_ = false;
+        // ppoll() installs signal_mask_ atomically for the duration of the
+        // wait, so a SIGINT/SIGTERM that arrived while signals were masked
+        // interrupts the wait instead of racing with it.
+        int ret = ppoll(poll_fd_.get(), 1, nullptr, &signal_mask_);
+        MaskAllSignalsExceptIntAndTerm();
+
+        // EINTR means a signal handler ran during the wait (e.g. SIGINT/SIGTERM
+        // cleared is_running_). That is the normal shutdown path, not an
+        // error; the loop condition re-checks is_running_.
+        if (ret == -1 && errno != EINTR) {
+            PLOG(ERROR) << "Snapuserd:ppoll error";
+            break;
         }
+
+        // An errored or invalid listening socket would make every later
+        // ppoll() return immediately, spinning this loop forever.
+        if (ret > 0 && (poll_fd_->revents & (POLLERR | POLLNVAL))) {
+            LOG(ERROR) << "Snapuserd: server socket error, revents: " << poll_fd_->revents;
+            break;
+        }
+
+        // revents is meaningless when ppoll() failed, and may carry other bits
+        // alongside POLLIN, so test the bit instead of comparing for equality.
+        if (ret > 0 && (poll_fd_->revents & POLLIN)) {
+            if (server_.AcceptClient() == static_cast<int>(DaemonOperations::STOP)) {
+                Daemon::Instance().is_running_ = false;
+            }
+        }
+
+        // Mask all signals to ensure that is_running_ can't become false between
+        // checking it in the while condition and calling into ppoll()
+        MaskAllSignals();
     }
 }
 
@@ -43,6 +106,10 @@
             Daemon::Instance().is_running_ = false;
             break;
         }
+        case SIGPIPE: {
+            LOG(ERROR) << "Received SIGPIPE signal";
+            break;
+        }
         default:
             LOG(ERROR) << "Received unknown signal " << signal;
             break;
diff --git a/fs_mgr/libsnapshot/snapuserd_server.cpp b/fs_mgr/libsnapshot/snapuserd_server.cpp
index 1e8b642..1f8dd63 100644
--- a/fs_mgr/libsnapshot/snapuserd_server.cpp
+++ b/fs_mgr/libsnapshot/snapuserd_server.cpp
@@ -1,3 +1,30 @@
+/*
+ * Copyright (C) 2020 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <arpa/inet.h>
+#include <cutils/sockets.h>
+#include <errno.h>
+#include <netinet/in.h>
+#include <stdint.h>
+#include <stdio.h>
+#include <stdlib.h>
+#include <sys/socket.h>
+#include <sys/types.h>
+#include <unistd.h>
+
 #include <android-base/logging.h>
 #include <libsnapshot/snapuserd.h>
 #include <libsnapshot/snapuserd_server.h>
@@ -38,9 +65,9 @@
 // new thread
 void SnapuserdServer::ThreadStart(std::string cow_device, std::string backing_device) {
     Snapuserd snapd(cow_device, backing_device);
-    if (snapd.Init()) {
+    if (!snapd.Init()) {
         PLOG(ERROR) << "Snapuserd: Init failed";
-        exit(EXIT_FAILURE);
+        return;
     }
 
     while (StopRequested() == false) {