Merge "snapuserd: Allow connecting to the first-stage daemon."
diff --git a/fs_mgr/file_wait.cpp b/fs_mgr/file_wait.cpp
index cbf6845..af0699b 100644
--- a/fs_mgr/file_wait.cpp
+++ b/fs_mgr/file_wait.cpp
@@ -206,6 +206,9 @@
 }
 
 int64_t OneShotInotify::RemainingMs() const {
+    if (relative_timeout_ == std::chrono::milliseconds::max()) {
+        return std::chrono::milliseconds::max().count();
+    }
     auto remaining = (std::chrono::steady_clock::now() - start_time_);
     auto elapsed = std::chrono::duration_cast<std::chrono::milliseconds>(remaining);
     return (relative_timeout_ - elapsed).count();
diff --git a/fs_mgr/include/fs_mgr/file_wait.h b/fs_mgr/include/fs_mgr/file_wait.h
index 74d160e..294e727 100644
--- a/fs_mgr/include/fs_mgr/file_wait.h
+++ b/fs_mgr/include/fs_mgr/file_wait.h
@@ -23,6 +23,9 @@
 // Wait at most |relative_timeout| milliseconds for |path| to exist. dirname(path)
 // must already exist. For example, to wait on /dev/block/dm-6, /dev/block must
 // be a valid directory.
+//
+// If relative_timeout is std::chrono::milliseconds::max(), then the wait will
+// block indefinitely.
 bool WaitForFile(const std::string& path, const std::chrono::milliseconds relative_timeout);
 
 // Wait at most |relative_timeout| milliseconds for |path| to stop existing.
diff --git a/fs_mgr/libsnapshot/snapuserd/Android.bp b/fs_mgr/libsnapshot/snapuserd/Android.bp
index 1739949..bc97afc 100644
--- a/fs_mgr/libsnapshot/snapuserd/Android.bp
+++ b/fs_mgr/libsnapshot/snapuserd/Android.bp
@@ -73,6 +73,7 @@
         "libbrotli",
         "libcutils_sockets",
         "libdm",
+        "libfs_mgr",
         "libgflags",
         "liblog",
         "libsnapshot_cow",
diff --git a/fs_mgr/libsnapshot/snapuserd/include/snapuserd/snapuserd_client.h b/fs_mgr/libsnapshot/snapuserd/include/snapuserd/snapuserd_client.h
index 280e857..aeecf41 100644
--- a/fs_mgr/libsnapshot/snapuserd/include/snapuserd/snapuserd_client.h
+++ b/fs_mgr/libsnapshot/snapuserd/include/snapuserd/snapuserd_client.h
@@ -31,6 +31,7 @@
 static constexpr uint32_t PACKET_SIZE = 512;
 
 static constexpr char kSnapuserdSocket[] = "snapuserd";
+static constexpr char kSnapuserdSocketProxy[] = "snapuserd_proxy";
 
 // Ensure that the second-stage daemon for snapuserd is running.
 bool EnsureSnapuserdStarted();
@@ -75,6 +76,9 @@
     // snapuserd to gracefully exit once all handler threads have terminated.
     // This should only be used on first-stage instances of snapuserd.
     bool DetachSnapuserd();
+
+    // Returns true if the snapuserd instance supports bridging a socket to second-stage init.
+    bool SupportsSecondStageSocketHandoff();
 };
 
 }  // namespace snapshot
diff --git a/fs_mgr/libsnapshot/snapuserd/snapuserd.rc b/fs_mgr/libsnapshot/snapuserd/snapuserd.rc
index 4bf34a2..02fda8d 100644
--- a/fs_mgr/libsnapshot/snapuserd/snapuserd.rc
+++ b/fs_mgr/libsnapshot/snapuserd/snapuserd.rc
@@ -5,3 +5,12 @@
     user root
     group root system
     seclabel u:r:snapuserd:s0
+
+service snapuserd_proxy /system/bin/snapuserd -socket-handoff
+    socket snapuserd stream 0660 system system
+    socket snapuserd_proxy seqpacket 0660 system root
+    oneshot
+    disabled
+    user root
+    group root system
+    seclabel u:r:snapuserd:s0
diff --git a/fs_mgr/libsnapshot/snapuserd/snapuserd_client.cpp b/fs_mgr/libsnapshot/snapuserd/snapuserd_client.cpp
index a72cdc9..81e9228 100644
--- a/fs_mgr/libsnapshot/snapuserd/snapuserd_client.cpp
+++ b/fs_mgr/libsnapshot/snapuserd/snapuserd_client.cpp
@@ -141,6 +141,16 @@
     return true;
 }
 
+bool SnapuserdClient::SupportsSecondStageSocketHandoff() {
+    // Capability query: any reply other than "success" is treated as "not supported".
+    std::string query = "supports,second_stage_socket_handoff";
+    if (Sendmsg(query)) {
+        return Receivemsg() == "success";
+    }
+    LOG(ERROR) << "Failed to send message " << query << " to snapuserd";
+    return false;
+}
+
 std::string SnapuserdClient::Receivemsg() {
     char msg[PACKET_SIZE];
     ssize_t ret = TEMP_FAILURE_RETRY(recv(sockfd_, msg, sizeof(msg), 0));
diff --git a/fs_mgr/libsnapshot/snapuserd/snapuserd_daemon.cpp b/fs_mgr/libsnapshot/snapuserd/snapuserd_daemon.cpp
index 4152a64..e05822e 100644
--- a/fs_mgr/libsnapshot/snapuserd/snapuserd_daemon.cpp
+++ b/fs_mgr/libsnapshot/snapuserd/snapuserd_daemon.cpp
@@ -26,6 +26,8 @@
 DEFINE_string(socket, android::snapshot::kSnapuserdSocket, "Named socket or socket path.");
 DEFINE_bool(no_socket, false,
             "If true, no socket is used. Each additional argument is an INIT message.");
+DEFINE_bool(socket_handoff, false,
+            "If true, perform a socket hand-off with an existing snapuserd instance, then exit.");
 
 namespace android {
 namespace snapshot {
@@ -33,8 +35,28 @@
 bool Daemon::StartServer(int argc, char** argv) {
     int arg_start = gflags::ParseCommandLineFlags(&argc, &argv, true);
 
+    sigfillset(&signal_mask_);
+    sigdelset(&signal_mask_, SIGINT);
+    sigdelset(&signal_mask_, SIGTERM);
+    sigdelset(&signal_mask_, SIGUSR1);
+
+    // Masking signals here ensures that after this point, we won't handle INT/TERM
+    // until after we call into ppoll()
+    signal(SIGINT, Daemon::SignalHandler);
+    signal(SIGTERM, Daemon::SignalHandler);
+    signal(SIGPIPE, Daemon::SignalHandler);
+    signal(SIGUSR1, Daemon::SignalHandler);
+
+    MaskAllSignalsExceptIntAndTerm();
+
+    if (FLAGS_socket_handoff) {
+        return server_.RunForSocketHandoff();
+    }
     if (!FLAGS_no_socket) {
-        return server_.Start(FLAGS_socket);
+        if (!server_.Start(FLAGS_socket)) {
+            return false;
+        }
+        return server_.Run();
     }
 
     for (int i = arg_start; i < argc; i++) {
@@ -51,8 +73,7 @@
 
     // Skip the accept() call to avoid spurious log spam. The server will still
     // run until all handlers have completed.
-    server_.SetTerminating();
-    return true;
+    return server_.WaitForSocket();
 }
 
 void Daemon::MaskAllSignalsExceptIntAndTerm() {
@@ -61,6 +82,7 @@
     sigdelset(&signal_mask, SIGINT);
     sigdelset(&signal_mask, SIGTERM);
     sigdelset(&signal_mask, SIGPIPE);
+    sigdelset(&signal_mask, SIGUSR1);
     if (sigprocmask(SIG_SETMASK, &signal_mask, NULL) != 0) {
         PLOG(ERROR) << "Failed to set sigprocmask";
     }
@@ -74,28 +96,14 @@
     }
 }
 
-void Daemon::Run() {
-    sigfillset(&signal_mask_);
-    sigdelset(&signal_mask_, SIGINT);
-    sigdelset(&signal_mask_, SIGTERM);
-
-    // Masking signals here ensure that after this point, we won't handle INT/TERM
-    // until after we call into ppoll()
-    signal(SIGINT, Daemon::SignalHandler);
-    signal(SIGTERM, Daemon::SignalHandler);
-    signal(SIGPIPE, Daemon::SignalHandler);
-
-    LOG(DEBUG) << "Snapuserd-server: ready to accept connections";
-
-    MaskAllSignalsExceptIntAndTerm();
-
-    server_.Run();
-}
-
 void Daemon::Interrupt() {
     server_.Interrupt();
 }
 
+void Daemon::ReceivedSocketSignal() {
+    server_.ReceivedSocketSignal();  // Relays the SIGUSR1 notification from SignalHandler().
+}
+
 void Daemon::SignalHandler(int signal) {
     LOG(DEBUG) << "Snapuserd received signal: " << signal;
     switch (signal) {
@@ -108,6 +116,11 @@
             LOG(ERROR) << "Received SIGPIPE signal";
             break;
         }
+        case SIGUSR1: {
+            LOG(INFO) << "Received SIGUSR1, attaching to proxy socket";
+            Daemon::Instance().ReceivedSocketSignal();
+            break;
+        }
         default:
             LOG(ERROR) << "Received unknown signal " << signal;
             break;
@@ -126,7 +139,5 @@
         LOG(ERROR) << "Snapuserd daemon failed to start.";
         exit(EXIT_FAILURE);
     }
-    daemon.Run();
-
     return 0;
 }
diff --git a/fs_mgr/libsnapshot/snapuserd/snapuserd_daemon.h b/fs_mgr/libsnapshot/snapuserd/snapuserd_daemon.h
index f8afac5..b660ba2 100644
--- a/fs_mgr/libsnapshot/snapuserd/snapuserd_daemon.h
+++ b/fs_mgr/libsnapshot/snapuserd/snapuserd_daemon.h
@@ -36,8 +36,8 @@
     }
 
     bool StartServer(int argc, char** argv);
-    void Run();
     void Interrupt();
+    void ReceivedSocketSignal();
 
   private:
     // Signal mask used with ppoll()
diff --git a/fs_mgr/libsnapshot/snapuserd/snapuserd_server.cpp b/fs_mgr/libsnapshot/snapuserd/snapuserd_server.cpp
index 8339690..a29b19b 100644
--- a/fs_mgr/libsnapshot/snapuserd/snapuserd_server.cpp
+++ b/fs_mgr/libsnapshot/snapuserd/snapuserd_server.cpp
@@ -25,14 +25,26 @@
 #include <sys/types.h>
 #include <unistd.h>
 
+#include <android-base/cmsg.h>
 #include <android-base/logging.h>
-
+#include <android-base/properties.h>
+#include <android-base/scopeguard.h>
+#include <fs_mgr/file_wait.h>
+#include <snapuserd/snapuserd_client.h>
 #include "snapuserd.h"
 #include "snapuserd_server.h"
 
+#define _REALLY_INCLUDE_SYS__SYSTEM_PROPERTIES_H_
+#include <sys/_system_properties.h>
+
 namespace android {
 namespace snapshot {
 
+using namespace std::string_literals;
+
+using android::base::borrowed_fd;
+using android::base::unique_fd;
+
 DaemonOperations SnapuserdServer::Resolveop(std::string& input) {
     if (input == "init") return DaemonOperations::INIT;
     if (input == "start") return DaemonOperations::START;
@@ -40,6 +52,7 @@
     if (input == "query") return DaemonOperations::QUERY;
     if (input == "delete") return DaemonOperations::DELETE;
     if (input == "detach") return DaemonOperations::DETACH;
+    if (input == "supports") return DaemonOperations::SUPPORTS;
 
     return DaemonOperations::INVALID;
 }
@@ -193,6 +206,16 @@
             terminating_ = true;
             return true;
         }
+        case DaemonOperations::SUPPORTS: {
+            if (out.size() != 2) {
+                LOG(ERROR) << "Malformed supports message, " << out.size() << " parts";
+                return Sendmsg(fd, "fail");
+            }
+            if (out[1] == "second_stage_socket_handoff") {
+                return Sendmsg(fd, "success");
+            }
+            return Sendmsg(fd, "fail");
+        }
         default: {
             LOG(ERROR) << "Received unknown message type from client";
             Sendmsg(fd, "fail");
@@ -245,28 +268,36 @@
 }
 
 bool SnapuserdServer::Start(const std::string& socketname) {
+    bool start_listening = true;
+
     sockfd_.reset(android_get_control_socket(socketname.c_str()));
-    if (sockfd_ >= 0) {
-        if (listen(sockfd_.get(), 4) < 0) {
-            PLOG(ERROR) << "listen socket failed: " << socketname;
-            return false;
-        }
-    } else {
+    if (sockfd_ < 0) {
         sockfd_.reset(socket_local_server(socketname.c_str(), ANDROID_SOCKET_NAMESPACE_RESERVED,
                                           SOCK_STREAM));
         if (sockfd_ < 0) {
             PLOG(ERROR) << "Failed to create server socket " << socketname;
             return false;
         }
+        start_listening = false;
+    }
+    return StartWithSocket(start_listening);
+}
+
+bool SnapuserdServer::StartWithSocket(bool start_listening) {
+    if (start_listening && listen(sockfd_.get(), 4) < 0) {
+        PLOG(ERROR) << "listen socket failed";
+        return false;
     }
 
-    AddWatchedFd(sockfd_);
+    AddWatchedFd(sockfd_, POLLIN);
 
-    LOG(DEBUG) << "Snapuserd server successfully started with socket name " << socketname;
+    LOG(DEBUG) << "Snapuserd server now accepting connections";
     return true;
 }
 
 bool SnapuserdServer::Run() {
+    LOG(INFO) << "Now listening on snapuserd socket";
+
     while (!IsTerminating()) {
         int rv = TEMP_FAILURE_RETRY(poll(watched_fds_.data(), watched_fds_.size(), -1));
         if (rv < 0) {
@@ -311,10 +342,10 @@
     }
 }
 
-void SnapuserdServer::AddWatchedFd(android::base::borrowed_fd fd) {
+void SnapuserdServer::AddWatchedFd(android::base::borrowed_fd fd, int events) {
     struct pollfd p = {};
     p.fd = fd.get();
-    p.events = POLLIN;
+    p.events = events;
     watched_fds_.emplace_back(std::move(p));
 }
 
@@ -325,7 +356,7 @@
         return;
     }
 
-    AddWatchedFd(fd);
+    AddWatchedFd(fd, POLLIN);
 }
 
 bool SnapuserdServer::HandleClient(android::base::borrowed_fd fd, int revents) {
@@ -422,5 +453,97 @@
     return true;
 }
 
+// Waits for the second-stage snapuserd_proxy socket, receives the server socket fd over it,
+// and then serves on that fd via Run(). JoinAllThreads() runs on every exit path.
+// Returns false if any step of the handoff fails.
+bool SnapuserdServer::WaitForSocket() {
+    auto scope_guard = android::base::make_scope_guard([this]() -> void { JoinAllThreads(); });
+
+    auto socket_path = ANDROID_SOCKET_DIR "/"s + kSnapuserdSocketProxy;
+
+    if (!android::fs_mgr::WaitForFile(socket_path, std::chrono::milliseconds::max())) {
+        LOG(ERROR)
+                << "Failed to wait for proxy socket, second-stage snapuserd will fail to connect";
+        return false;
+    }
+
+    // We must re-initialize property service access, since we launched before
+    // second-stage init.
+    __system_properties_init();
+
+    if (!android::base::WaitForProperty("snapuserd.proxy_ready", "true")) {
+        LOG(ERROR)
+                << "Failed to wait for proxy property, second-stage snapuserd will fail to connect";
+        return false;
+    }
+
+    unique_fd fd(socket_local_client(kSnapuserdSocketProxy, ANDROID_SOCKET_NAMESPACE_RESERVED,
+                                     SOCK_SEQPACKET));
+    if (fd < 0) {
+        PLOG(ERROR) << "Failed to connect to socket proxy";
+        return false;
+    }
+
+    char code[1];
+    std::vector<unique_fd> fds;
+    ssize_t rv = android::base::ReceiveFileDescriptorVector(fd, code, sizeof(code), 1, &fds);
+    if (rv < 0) {
+        PLOG(ERROR) << "Failed to receive server socket over proxy";
+        return false;
+    }
+    if (fds.empty()) {
+        LOG(ERROR) << "Expected at least one file descriptor from proxy";
+        return false;
+    }
+
+    // ACK the transfer so the proxy, which blocks in recv() after sending the fd, can finish.
+    code[0] = 'a';
+    if (TEMP_FAILURE_RETRY(send(fd, code, sizeof(code), MSG_NOSIGNAL)) < 0) {
+        PLOG(ERROR) << "Failed to send ACK to proxy";
+        return false;
+    }
+
+    sockfd_ = std::move(fds[0]);
+    return StartWithSocket(true) && Run();
+}
+
+bool SnapuserdServer::RunForSocketHandoff() {
+    unique_fd proxy_fd(android_get_control_socket(kSnapuserdSocketProxy));
+    if (proxy_fd < 0) {
+        PLOG(FATAL) << "Proxy could not get android control socket " << kSnapuserdSocketProxy;
+    }
+    borrowed_fd server_fd(android_get_control_socket(kSnapuserdSocket));
+    if (server_fd < 0) {
+        PLOG(FATAL) << "Proxy could not get android control socket " << kSnapuserdSocket;
+    }
+
+    if (listen(proxy_fd.get(), 4) < 0) {
+        PLOG(FATAL) << "Proxy listen socket failed";
+    }
+
+    if (!android::base::SetProperty("snapuserd.proxy_ready", "true")) {
+        LOG(FATAL) << "Proxy failed to set ready property";
+    }
+
+    unique_fd client_fd(
+            TEMP_FAILURE_RETRY(accept4(proxy_fd.get(), nullptr, nullptr, SOCK_CLOEXEC)));
+    if (client_fd < 0) {
+        PLOG(FATAL) << "Proxy accept failed";
+    }
+
+    char code[1] = {'a'};
+    std::vector<int> fds = {server_fd.get()};
+    ssize_t rv = android::base::SendFileDescriptorVector(client_fd, code, sizeof(code), fds);
+    if (rv < 0) {
+        PLOG(FATAL) << "Proxy could not send file descriptor to snapuserd";
+    }
+    // Wait for an ACK - results don't matter, we just don't want to risk closing
+    // the proxy socket too early.
+    if (recv(client_fd, code, sizeof(code), 0) < 0) {
+        PLOG(FATAL) << "Proxy could not receive terminating code from snapuserd";
+    }
+    return true;
+}
+
 }  // namespace snapshot
 }  // namespace android
diff --git a/fs_mgr/libsnapshot/snapuserd/snapuserd_server.h b/fs_mgr/libsnapshot/snapuserd/snapuserd_server.h
index 6699189..846f848 100644
--- a/fs_mgr/libsnapshot/snapuserd/snapuserd_server.h
+++ b/fs_mgr/libsnapshot/snapuserd/snapuserd_server.h
@@ -42,6 +42,7 @@
     STOP,
     DELETE,
     DETACH,
+    SUPPORTS,
     INVALID,
 };
 
@@ -93,6 +94,7 @@
   private:
     android::base::unique_fd sockfd_;
     bool terminating_;
+    volatile bool received_socket_signal_ = false;
     std::vector<struct pollfd> watched_fds_;
 
     std::mutex lock_;
@@ -100,7 +102,7 @@
     using HandlerList = std::vector<std::shared_ptr<DmUserHandler>>;
     HandlerList dm_users_;
 
-    void AddWatchedFd(android::base::borrowed_fd fd);
+    void AddWatchedFd(android::base::borrowed_fd fd, int events);
     void AcceptClient();
     bool HandleClient(android::base::borrowed_fd fd, int revents);
     bool Recv(android::base::borrowed_fd fd, std::string* data);
@@ -117,6 +119,7 @@
 
     void RunThread(std::shared_ptr<DmUserHandler> handler);
     void JoinAllThreads();
+    bool StartWithSocket(bool start_listening);
 
     // Find a DmUserHandler within a lock.
     HandlerList::iterator FindHandler(std::lock_guard<std::mutex>* proof_of_lock,
@@ -129,6 +132,8 @@
     bool Start(const std::string& socketname);
     bool Run();
     void Interrupt();
+    bool RunForSocketHandoff();
+    bool WaitForSocket();
 
     std::shared_ptr<DmUserHandler> AddHandler(const std::string& misc_name,
                                               const std::string& cow_device_path,
@@ -136,6 +141,7 @@
     bool StartHandler(const std::shared_ptr<DmUserHandler>& handler);
 
     void SetTerminating() { terminating_ = true; }
+    void ReceivedSocketSignal() { received_socket_signal_ = true; }
 };
 
 }  // namespace snapshot
diff --git a/init/init.cpp b/init/init.cpp
index a7325ca..bde8e04 100644
--- a/init/init.cpp
+++ b/init/init.cpp
@@ -725,6 +725,40 @@
     }
 }
 
+// Registers the first-stage snapuserd with init and starts snapuserd_proxy for socket handoff.
+static Result<void> ConnectEarlyStageSnapuserdAction(const BuiltinArguments& args) {
+    auto pid = GetSnapuserdFirstStagePid();
+    if (!pid) {
+        return {};
+    }
+
+    auto info = GetSnapuserdFirstStageInfo();
+    if (std::find(info.begin(), info.end(), "socket"s) == info.end()) {
+        // snapuserd does not support socket handoff, so exit early.
+        return {};
+    }
+
+    auto svc = ServiceList::GetInstance().FindService("snapuserd");
+    if (!svc) {
+        LOG(FATAL) << "Failed to find snapuserd service entry";
+    }
+
+    svc->SetShutdownCritical();
+    svc->SetStartedInFirstStage(*pid);
+
+    svc = ServiceList::GetInstance().FindService("snapuserd_proxy");
+    if (!svc) {
+        LOG(FATAL) << "Failed to find snapuserd_proxy service entry, merge will never initiate";
+    }
+    if (!svc->MarkSocketPersistent("snapuserd")) {
+        LOG(FATAL) << "Could not find snapuserd socket in snapuserd_proxy service entry";
+    }
+    if (auto result = svc->Start(); !result.ok()) {
+        LOG(FATAL) << "Could not start snapuserd_proxy: " << result.error();
+    }
+    return {};
+}
+
 int SecondStageMain(int argc, char** argv) {
     if (REBOOT_BOOTLOADER_ON_PANIC) {
         InstallRebootSignalHandlers();
@@ -852,6 +886,7 @@
     am.QueueBuiltinAction(SetupCgroupsAction, "SetupCgroups");
     am.QueueBuiltinAction(SetKptrRestrictAction, "SetKptrRestrict");
     am.QueueBuiltinAction(TestPerfEventSelinuxAction, "TestPerfEventSelinux");
+    am.QueueBuiltinAction(ConnectEarlyStageSnapuserdAction, "ConnectEarlyStageSnapuserd");
     am.QueueEventTrigger("early-init");
 
     // Queue an action that waits for coldboot done so we know ueventd has set up all of /dev...
diff --git a/init/service.cpp b/init/service.cpp
index c3069f5..489dd67 100644
--- a/init/service.cpp
+++ b/init/service.cpp
@@ -269,6 +269,9 @@
 
     // Remove any socket resources we may have created.
     for (const auto& socket : sockets_) {
+        if (socket.persist) {
+            continue;
+        }
         auto path = ANDROID_SOCKET_DIR "/" + socket.name;
         unlink(path.c_str());
     }
@@ -409,9 +412,7 @@
     }
 
     bool disabled = (flags_ & (SVC_DISABLED | SVC_RESET));
-    // Starting a service removes it from the disabled or reset state and
-    // immediately takes it out of the restarting state if it was in there.
-    flags_ &= (~(SVC_DISABLED|SVC_RESTARTING|SVC_RESET|SVC_RESTART|SVC_DISABLED_START));
+    ResetFlagsForStart();
 
     // Running processes require no additional work --- if they're in the
     // process of exiting, we've ensured that they will immediately restart
@@ -622,6 +623,23 @@
     return {};
 }
 
+// Adopts an already-running process (|pid|) as this service and reports it as "running".
+void Service::SetStartedInFirstStage(pid_t pid) {
+    LOG(INFO) << "adding first-stage service '" << name_ << "'...";
+    pid_ = pid;
+    flags_ |= SVC_RUNNING;
+    start_order_ = next_start_order_++;
+    time_started_ = boot_clock::now();  // real start was earlier; precision is not needed
+
+    NotifyStateChange("running");
+}
+
+void Service::ResetFlagsForStart() {
+    // Starting a service removes it from the disabled or reset state and takes it out of the
+    // restarting state; the SVC_RESTART and SVC_DISABLED_START bits are cleared as well.
+    flags_ &= ~(SVC_DISABLED | SVC_RESTARTING | SVC_RESET | SVC_RESTART | SVC_DISABLED_START);
+}
+
 Result<void> Service::StartIfNotDisabled() {
     if (!(flags_ & SVC_DISABLED)) {
         return Start();
@@ -792,5 +810,18 @@
                                      nullptr, str_args, false);
 }
 
+// Flags the socket named |socket_name| as persistent; returns false if this service declares
+// no such socket. This is used for snapuserd_proxy, which hands off a socket to snapuserd. It's
+// a special case to support the daemon launched in first-stage init. The persist feature is
+// not part of the init language and is only used here.
+bool Service::MarkSocketPersistent(const std::string& socket_name) {
+    for (auto& entry : sockets_) {
+        if (entry.name != socket_name) continue;
+        entry.persist = true;
+        return true;
+    }
+    return false;
+}
+
 }  // namespace init
 }  // namespace android
diff --git a/init/service.h b/init/service.h
index 043555f..ccf6899 100644
--- a/init/service.h
+++ b/init/service.h
@@ -99,6 +99,8 @@
     void AddReapCallback(std::function<void(const siginfo_t& siginfo)> callback) {
         reap_callbacks_.emplace_back(std::move(callback));
     }
+    void SetStartedInFirstStage(pid_t pid);
+    bool MarkSocketPersistent(const std::string& socket_name);
     size_t CheckAllCommands() const { return onrestart_.CheckAllCommands(); }
 
     static bool is_exec_service_running() { return is_exec_service_running_; }
@@ -144,6 +146,7 @@
     void StopOrReset(int how);
     void KillProcessGroup(int signal, bool report_oneshot = false);
     void SetProcessAttributesAndCaps();
+    void ResetFlagsForStart();
 
     static unsigned long next_start_order_;
     static bool is_exec_service_running_;
diff --git a/init/service_utils.h b/init/service_utils.h
index 1e0b4bd..9b65dca 100644
--- a/init/service_utils.h
+++ b/init/service_utils.h
@@ -54,6 +54,7 @@
     int perm = 0;
     std::string context;
     bool passcred = false;
+    bool persist = false;
 
     // Create() creates the named unix domain socket in /dev/socket and returns a Descriptor object.
     // It should be called when starting a service, before calling fork(), such that the socket is
diff --git a/init/snapuserd_transition.cpp b/init/snapuserd_transition.cpp
index 9a0b3b7..b8c2fd2 100644
--- a/init/snapuserd_transition.cpp
+++ b/init/snapuserd_transition.cpp
@@ -54,6 +54,7 @@
 static constexpr char kSnapuserdPath[] = "/system/bin/snapuserd";
 static constexpr char kSnapuserdFirstStagePidVar[] = "FIRST_STAGE_SNAPUSERD_PID";
 static constexpr char kSnapuserdFirstStageFdVar[] = "FIRST_STAGE_SNAPUSERD_FD";
+static constexpr char kSnapuserdFirstStageInfoVar[] = "FIRST_STAGE_SNAPUSERD_INFO";
 static constexpr char kSnapuserdLabel[] = "u:object_r:snapuserd_exec:s0";
 static constexpr char kSnapuserdSocketLabel[] = "u:object_r:snapuserd_socket:s0";
 
@@ -87,6 +88,14 @@
         _exit(127);
     }
 
+    auto client = SnapuserdClient::Connect(android::snapshot::kSnapuserdSocket, 10s);
+    if (!client) {
+        LOG(FATAL) << "Could not connect to first-stage snapuserd";
+    }
+    if (client->SupportsSecondStageSocketHandoff()) {
+        setenv(kSnapuserdFirstStageInfoVar, "socket", 1);
+    }
+
     setenv(kSnapuserdFirstStagePidVar, std::to_string(pid).c_str(), 1);
 
     LOG(INFO) << "Relaunched snapuserd with pid: " << pid;
@@ -328,5 +337,13 @@
     return GetSnapuserdFirstStagePid().has_value();
 }
 
+std::vector<std::string> GetSnapuserdFirstStageInfo() {
+    // Comma-separated list; the variable is absent when first-stage init recorded nothing.
+    if (const char* info = getenv(kSnapuserdFirstStageInfoVar)) {
+        return android::base::Split(info, ",");
+    }
+    return {};
+}
+
 }  // namespace init
 }  // namespace android
diff --git a/init/snapuserd_transition.h b/init/snapuserd_transition.h
index a5ab652..62aee83 100644
--- a/init/snapuserd_transition.h
+++ b/init/snapuserd_transition.h
@@ -76,6 +76,9 @@
 // Return the pid of the first-stage instances of snapuserd, if it was started.
 std::optional<pid_t> GetSnapuserdFirstStagePid();
 
+// Return snapuserd info strings that were set during first-stage init.
+std::vector<std::string> GetSnapuserdFirstStageInfo();
+
 // Save an open fd to /system/bin (in the ramdisk) into an environment. This is
 // used to later execveat() snapuserd.
 void SaveRamdiskPathToSnapuserd();