diff --git a/init/init.cpp b/init/init.cpp
index a7325ca..bde8e04 100644
--- a/init/init.cpp
+++ b/init/init.cpp
@@ -725,6 +725,40 @@
     }
 }
 
+// If first-stage snapuserd supports socket handoff, register its pid with the
+// "snapuserd" service and start "snapuserd_proxy" with a persistent socket.
+static Result<void> ConnectEarlyStageSnapuserdAction(const BuiltinArguments& args) {
+    auto pid = GetSnapuserdFirstStagePid();
+    if (!pid) {
+        return {};  // snapuserd was not launched in first-stage init.
+    }
+
+    auto info = GetSnapuserdFirstStageInfo();
+    if (std::find(info.begin(), info.end(), "socket"s) == info.end()) {
+        // snapuserd does not support socket handoff, so exit early.
+        return {};
+    }
+
+    auto svc = ServiceList::GetInstance().FindService("snapuserd");
+    if (!svc) {
+        LOG(FATAL) << "Failed to find snapuserd service entry";
+    }
+    svc->SetShutdownCritical();
+    svc->SetStartedInFirstStage(*pid);  // Adopt the already-running daemon.
+
+    svc = ServiceList::GetInstance().FindService("snapuserd_proxy");
+    if (!svc) {
+        LOG(FATAL) << "Failed to find snapuserd_proxy service entry, merge will never initiate";
+    }
+    if (!svc->MarkSocketPersistent("snapuserd")) {
+        LOG(FATAL) << "Could not find snapuserd socket in snapuserd_proxy service entry";
+    }
+    if (auto result = svc->Start(); !result.ok()) {
+        LOG(FATAL) << "Could not start snapuserd_proxy: " << result.error();
+    }
+    return {};
+}
+
 int SecondStageMain(int argc, char** argv) {
     if (REBOOT_BOOTLOADER_ON_PANIC) {
         InstallRebootSignalHandlers();
@@ -852,6 +886,7 @@
     am.QueueBuiltinAction(SetupCgroupsAction, "SetupCgroups");
     am.QueueBuiltinAction(SetKptrRestrictAction, "SetKptrRestrict");
     am.QueueBuiltinAction(TestPerfEventSelinuxAction, "TestPerfEventSelinux");
+    am.QueueBuiltinAction(ConnectEarlyStageSnapuserdAction, "ConnectEarlyStageSnapuserd");
     am.QueueEventTrigger("early-init");
 
     // Queue an action that waits for coldboot done so we know ueventd has set up all of /dev...
diff --git a/init/service.cpp b/init/service.cpp
index c3069f5..489dd67 100644
--- a/init/service.cpp
+++ b/init/service.cpp
@@ -269,6 +269,9 @@
 
     // Remove any socket resources we may have created.
     for (const auto& socket : sockets_) {
+        if (socket.persist) {
+            continue;
+        }
         auto path = ANDROID_SOCKET_DIR "/" + socket.name;
         unlink(path.c_str());
     }
@@ -409,9 +412,7 @@
     }
 
     bool disabled = (flags_ & (SVC_DISABLED | SVC_RESET));
-    // Starting a service removes it from the disabled or reset state and
-    // immediately takes it out of the restarting state if it was in there.
-    flags_ &= (~(SVC_DISABLED|SVC_RESTARTING|SVC_RESET|SVC_RESTART|SVC_DISABLED_START));
+    ResetFlagsForStart();
 
     // Running processes require no additional work --- if they're in the
     // process of exiting, we've ensured that they will immediately restart
@@ -622,6 +623,23 @@
     return {};
 }
 
+void Service::SetStartedInFirstStage(pid_t pid) {
+    LOG(INFO) << "adding first-stage service '" << name_ << "'...";
+    // Record a process already launched by first-stage init; nothing is forked here.
+    pid_ = pid;
+    time_started_ = boot_clock::now();  // approximate; the real launch time is unknown
+    start_order_ = next_start_order_++;
+    flags_ |= SVC_RUNNING;
+
+    NotifyStateChange("running");
+}
+
+void Service::ResetFlagsForStart() {
+    // A start leaves the disabled/reset states and cancels any in-progress restart.
+    const auto bits = SVC_DISABLED_START | SVC_RESTART | SVC_RESET | SVC_RESTARTING | SVC_DISABLED;
+    flags_ &= ~bits;
+}
+
 Result<void> Service::StartIfNotDisabled() {
     if (!(flags_ & SVC_DISABLED)) {
         return Start();
@@ -792,5 +810,18 @@
                                      nullptr, str_args, false);
 }
 
+// Marks the named socket as persistent so it is not unlinked when this service's
+// socket resources are cleaned up. Only snapuserd_proxy uses this: it hands its socket
+// off to the snapuserd daemon launched in first-stage init. "persist" is not part of
+// the init language. Returns false if the service has no socket with that name.
+bool Service::MarkSocketPersistent(const std::string& socket_name) {
+    for (auto& descriptor : sockets_) {
+        if (descriptor.name != socket_name) continue;
+        descriptor.persist = true;  // skipped by the socket cleanup loop
+        return true;
+    }
+    return false;
+}
+
 }  // namespace init
 }  // namespace android
diff --git a/init/service.h b/init/service.h
index 043555f..ccf6899 100644
--- a/init/service.h
+++ b/init/service.h
@@ -99,6 +99,8 @@
     void AddReapCallback(std::function<void(const siginfo_t& siginfo)> callback) {
         reap_callbacks_.emplace_back(std::move(callback));
     }
+    void SetStartedInFirstStage(pid_t pid);
+    bool MarkSocketPersistent(const std::string& socket_name);
     size_t CheckAllCommands() const { return onrestart_.CheckAllCommands(); }
 
     static bool is_exec_service_running() { return is_exec_service_running_; }
@@ -144,6 +146,7 @@
     void StopOrReset(int how);
     void KillProcessGroup(int signal, bool report_oneshot = false);
     void SetProcessAttributesAndCaps();
+    void ResetFlagsForStart();
 
     static unsigned long next_start_order_;
     static bool is_exec_service_running_;
diff --git a/init/service_utils.h b/init/service_utils.h
index 1e0b4bd..9b65dca 100644
--- a/init/service_utils.h
+++ b/init/service_utils.h
@@ -54,6 +54,7 @@
     int perm = 0;
     std::string context;
     bool passcred = false;
+    bool persist = false;
 
     // Create() creates the named unix domain socket in /dev/socket and returns a Descriptor object.
     // It should be called when starting a service, before calling fork(), such that the socket is
diff --git a/init/snapuserd_transition.cpp b/init/snapuserd_transition.cpp
index 9a0b3b7..b8c2fd2 100644
--- a/init/snapuserd_transition.cpp
+++ b/init/snapuserd_transition.cpp
@@ -54,6 +54,7 @@
 static constexpr char kSnapuserdPath[] = "/system/bin/snapuserd";
 static constexpr char kSnapuserdFirstStagePidVar[] = "FIRST_STAGE_SNAPUSERD_PID";
 static constexpr char kSnapuserdFirstStageFdVar[] = "FIRST_STAGE_SNAPUSERD_FD";
+static constexpr char kSnapuserdFirstStageInfoVar[] = "FIRST_STAGE_SNAPUSERD_INFO";
 static constexpr char kSnapuserdLabel[] = "u:object_r:snapuserd_exec:s0";
 static constexpr char kSnapuserdSocketLabel[] = "u:object_r:snapuserd_socket:s0";
 
@@ -87,6 +88,14 @@
         _exit(127);
     }
 
+    auto client = SnapuserdClient::Connect(android::snapshot::kSnapuserdSocket, 10s);
+    if (!client) {
+        LOG(FATAL) << "Could not connect to first-stage snapuserd";
+    }
+    if (client->SupportsSecondStageSocketHandoff()) {
+        setenv(kSnapuserdFirstStageInfoVar, "socket", 1);
+    }
+
     setenv(kSnapuserdFirstStagePidVar, std::to_string(pid).c_str(), 1);
 
     LOG(INFO) << "Relaunched snapuserd with pid: " << pid;
@@ -328,5 +337,13 @@
     return GetSnapuserdFirstStagePid().has_value();
 }
 
+// Returns the comma-separated tokens (e.g. "socket") that first-stage init stored in
+// FIRST_STAGE_SNAPUSERD_INFO, or an empty vector if that variable is unset.
+std::vector<std::string> GetSnapuserdFirstStageInfo() {
+    const char* info_str = getenv(kSnapuserdFirstStageInfoVar);
+    if (!info_str) return {};
+    return android::base::Split(info_str, ",");
+}
+
 }  // namespace init
 }  // namespace android
diff --git a/init/snapuserd_transition.h b/init/snapuserd_transition.h
index a5ab652..62aee83 100644
--- a/init/snapuserd_transition.h
+++ b/init/snapuserd_transition.h
@@ -76,6 +76,9 @@
 // Return the pid of the first-stage instances of snapuserd, if it was started.
 std::optional<pid_t> GetSnapuserdFirstStagePid();
 
+// Return the snapuserd info strings (e.g. "socket") recorded by first-stage init, if any.
+std::vector<std::string> GetSnapuserdFirstStageInfo();
+
 // Save an open fd to /system/bin (in the ramdisk) into an environment. This is
 // used to later execveat() snapuserd.
 void SaveRamdiskPathToSnapuserd();
