Merge changes from topic "snapuserd-selinux"

* changes:
  libsnapshot: Fix tests that depend on PrepareOneSnapshot().
  libsnapshot: Ensure dm-user devices are destroyed after a merge.
  libsnapshot: Fix tests for mapping snapshots in first-stage init.
  init: Add an selinux transition for snapuserd.
diff --git a/fs_mgr/libsnapshot/include/libsnapshot/snapshot.h b/fs_mgr/libsnapshot/include/libsnapshot/snapshot.h
index f8d3cdc..397ff2e 100644
--- a/fs_mgr/libsnapshot/include/libsnapshot/snapshot.h
+++ b/fs_mgr/libsnapshot/include/libsnapshot/snapshot.h
@@ -306,13 +306,17 @@
     // Helper function for second stage init to restorecon on the rollback indicator.
     static std::string GetGlobalRollbackIndicatorPath();
 
-    // Initiate the transition from first-stage to second-stage snapuserd. This
-    // process involves re-creating the dm-user table entries for each device,
-    // so that they connect to the new daemon. Once all new tables have been
-    // activated, we ask the first-stage daemon to cleanly exit.
-    //
-    // The caller must pass a function which starts snapuserd.
-    bool PerformSecondStageTransition();
+    // Detach dm-user devices from the current snapuserd, and populate
+    // |snapuserd_argv| with the necessary arguments to restart snapuserd
+    // and reattach them.
+    bool DetachSnapuserdForSelinux(std::vector<std::string>* snapuserd_argv);
+
+    // Perform the transition from the selinux stage of snapuserd into the
+    // second stage of snapuserd. This process involves re-creating the dm-user
+    // table entries for each device, so that they connect to the new daemon.
+    // Once all new tables have been activated, we ask the first-stage daemon
+    // to cleanly exit.
+    bool PerformSecondStageInitTransition();
 
     // ISnapshotManager overrides.
     bool BeginUpdate() override;
@@ -626,6 +630,9 @@
     // The reverse of MapPartitionWithSnapshot.
     bool UnmapPartitionWithSnapshot(LockedFile* lock, const std::string& target_partition_name);
 
+    // Unmap a dm-user device through snapuserd.
+    bool UnmapDmUserDevice(const std::string& snapshot_name);
+
     // If there isn't a previous update, return true. |needs_merge| is set to false.
     // If there is a previous update but the device has not boot into it, tries to cancel the
     //   update and delete any snapshots. Return true if successful. |needs_merge| is set to false.
@@ -693,6 +700,19 @@
     // returns true.
     bool WaitForDevice(const std::string& device, std::chrono::milliseconds timeout_ms);
 
+    enum class InitTransition { SELINUX_DETACH, SECOND_STAGE };
+
+    // Initiate the transition from first-stage to second-stage snapuserd. This
+    // process involves re-creating the dm-user table entries for each device,
+    // so that they connect to the new daemon. Once all new tables have been
+    // activated, we ask the first-stage daemon to cleanly exit.
+    //
+    // If the mode is SELINUX_DETACH, snapuserd_argv must be non-null and will
+    // be populated with a list of snapuserd arguments to pass to execve(). In
+    // any other mode, snapuserd_argv is ignored.
+    bool PerformInitTransition(InitTransition transition,
+                               std::vector<std::string>* snapuserd_argv = nullptr);
+
     std::string gsid_dir_;
     std::string metadata_dir_;
     std::unique_ptr<IDeviceInfo> device_;
diff --git a/fs_mgr/libsnapshot/include/libsnapshot/snapuserd_client.h b/fs_mgr/libsnapshot/include/libsnapshot/snapuserd_client.h
index aa9ba6e..1dab361 100644
--- a/fs_mgr/libsnapshot/include/libsnapshot/snapuserd_client.h
+++ b/fs_mgr/libsnapshot/include/libsnapshot/snapuserd_client.h
@@ -32,8 +32,6 @@
 
 static constexpr char kSnapuserdSocket[] = "snapuserd";
 
-static constexpr char kSnapuserdFirstStagePidVar[] = "FIRST_STAGE_SNAPUSERD_PID";
-
 // Ensure that the second-stage daemon for snapuserd is running.
 bool EnsureSnapuserdStarted();
 
diff --git a/fs_mgr/libsnapshot/snapshot.cpp b/fs_mgr/libsnapshot/snapshot.cpp
index 04c3ccc..f0646fc 100644
--- a/fs_mgr/libsnapshot/snapshot.cpp
+++ b/fs_mgr/libsnapshot/snapshot.cpp
@@ -1247,6 +1247,10 @@
         return false;
     }
 
+    if (status.compression_enabled()) {
+        UnmapDmUserDevice(name);
+    }
+
     // Cleanup the base device as well, since it is no longer used. This does
     // not block cleanup.
     auto base_name = GetBaseDeviceName(name);
@@ -1291,12 +1295,13 @@
     return RemoveAllUpdateState(lock, before_cancel);
 }
 
-bool SnapshotManager::PerformSecondStageTransition() {
-    LOG(INFO) << "Performing second-stage transition for snapuserd.";
+bool SnapshotManager::PerformInitTransition(InitTransition transition,
+                                            std::vector<std::string>* snapuserd_argv) {
+    LOG(INFO) << "Performing transition for snapuserd.";
 
     // Don't use EnsuerSnapuserdConnected() because this is called from init,
     // and attempting to do so will deadlock.
-    if (!snapuserd_client_) {
+    if (!snapuserd_client_ && transition != InitTransition::SELINUX_DETACH) {
         snapuserd_client_ = SnapuserdClient::Connect(kSnapuserdSocket, 10s);
         if (!snapuserd_client_) {
             LOG(ERROR) << "Unable to connect to snapuserd";
@@ -1343,6 +1348,9 @@
         }
 
         auto misc_name = user_cow_name;
+        if (transition == InitTransition::SELINUX_DETACH) {
+            misc_name += "-selinux";
+        }
 
         DmTable table;
         table.Emplace<DmTargetUser>(0, target.spec.length, misc_name);
@@ -1378,6 +1386,17 @@
             continue;
         }
 
+        if (transition == InitTransition::SELINUX_DETACH) {
+            auto message = misc_name + "," + cow_image_device + "," + backing_device;
+            snapuserd_argv->emplace_back(std::move(message));
+
+            // Do not attempt to connect to the new snapuserd yet; it hasn't
+            // been started. We do, however, want to wait for the misc device
+            // to have been created.
+            ok_cows++;
+            continue;
+        }
+
         uint64_t base_sectors =
                 snapuserd_client_->InitDmUserCow(misc_name, cow_image_device, backing_device);
         if (base_sectors == 0) {
@@ -2048,27 +2067,8 @@
 
     auto& dm = DeviceMapper::Instance();
 
-    auto dm_user_name = GetDmUserCowName(name);
-    if (IsCompressionEnabled() && dm.GetState(dm_user_name) != DmDeviceState::INVALID) {
-        if (!EnsureSnapuserdConnected()) {
-            return false;
-        }
-        if (!dm.DeleteDeviceIfExists(dm_user_name)) {
-            LOG(ERROR) << "Cannot unmap " << dm_user_name;
-            return false;
-        }
-
-        if (!snapuserd_client_->WaitForDeviceDelete(dm_user_name)) {
-            LOG(ERROR) << "Failed to wait for " << dm_user_name << " control device to delete";
-            return false;
-        }
-
-        // Ensure the control device is gone so we don't run into ABA problems.
-        auto control_device = "/dev/dm-user/" + dm_user_name;
-        if (!android::fs_mgr::WaitForFileDeleted(control_device, 10s)) {
-            LOG(ERROR) << "Timed out waiting for " << control_device << " to unlink";
-            return false;
-        }
+    if (IsCompressionEnabled() && !UnmapDmUserDevice(name)) {
+        return false;
     }
 
     auto cow_name = GetCowName(name);
@@ -2085,6 +2085,37 @@
     return true;
 }
 
+bool SnapshotManager::UnmapDmUserDevice(const std::string& snapshot_name) {
+    auto& dm = DeviceMapper::Instance();
+
+    if (!EnsureSnapuserdConnected()) {
+        return false;
+    }
+
+    auto dm_user_name = GetDmUserCowName(snapshot_name);
+    if (dm.GetState(dm_user_name) == DmDeviceState::INVALID) {
+        return true;
+    }
+
+    if (!dm.DeleteDeviceIfExists(dm_user_name)) {
+        LOG(ERROR) << "Cannot unmap " << dm_user_name;
+        return false;
+    }
+
+    if (!snapuserd_client_->WaitForDeviceDelete(dm_user_name)) {
+        LOG(ERROR) << "Failed to wait for " << dm_user_name << " control device to delete";
+        return false;
+    }
+
+    // Ensure the control device is gone so we don't run into ABA problems.
+    auto control_device = "/dev/dm-user/" + dm_user_name;
+    if (!android::fs_mgr::WaitForFileDeleted(control_device, 10s)) {
+        LOG(ERROR) << "Timed out waiting for " << control_device << " to unlink";
+        return false;
+    }
+    return true;
+}
+
 bool SnapshotManager::MapAllSnapshots(const std::chrono::milliseconds& timeout_ms) {
     auto lock = LockExclusive();
     if (!lock) return false;
@@ -3311,5 +3342,13 @@
     return status.state() != UpdateState::None && status.compression_enabled();
 }
 
+bool SnapshotManager::DetachSnapuserdForSelinux(std::vector<std::string>* snapuserd_argv) {
+    return PerformInitTransition(InitTransition::SELINUX_DETACH, snapuserd_argv);
+}
+
+bool SnapshotManager::PerformSecondStageInitTransition() {
+    return PerformInitTransition(InitTransition::SECOND_STAGE);
+}
+
 }  // namespace snapshot
 }  // namespace android
diff --git a/fs_mgr/libsnapshot/snapshot_test.cpp b/fs_mgr/libsnapshot/snapshot_test.cpp
index c6a1786..bb44425 100644
--- a/fs_mgr/libsnapshot/snapshot_test.cpp
+++ b/fs_mgr/libsnapshot/snapshot_test.cpp
@@ -146,6 +146,7 @@
                 "base-device",
                 "test_partition_b",
                 "test_partition_b-base",
+                "test_partition_b-base",
         };
         for (const auto& partition : partitions) {
             DeleteDevice(partition);
@@ -180,12 +181,22 @@
     }
 
     // If |path| is non-null, the partition will be mapped after creation.
-    bool CreatePartition(const std::string& name, uint64_t size, std::string* path = nullptr) {
+    bool CreatePartition(const std::string& name, uint64_t size, std::string* path = nullptr,
+                         const std::optional<std::string> group = {}) {
         TestPartitionOpener opener(fake_super);
         auto builder = MetadataBuilder::New(opener, "super", 0);
         if (!builder) return false;
 
-        auto partition = builder->AddPartition(name, 0);
+        std::string partition_group = std::string(android::fs_mgr::kDefaultGroup);
+        if (group) {
+            partition_group = *group;
+        }
+        return CreatePartition(builder.get(), name, size, path, partition_group);
+    }
+
+    bool CreatePartition(MetadataBuilder* builder, const std::string& name, uint64_t size,
+                         std::string* path, const std::string& group) {
+        auto partition = builder->AddPartition(name, group, 0);
         if (!partition) return false;
         if (!builder->ResizePartition(partition, size)) {
             return false;
@@ -194,6 +205,8 @@
         // Update the source slot.
         auto metadata = builder->Export();
         if (!metadata) return false;
+
+        TestPartitionOpener opener(fake_super);
         if (!UpdatePartitionTable(opener, "super", *metadata.get(), 0)) {
             return false;
         }
@@ -210,39 +223,54 @@
         return CreateLogicalPartition(params, path);
     }
 
-    bool MapUpdatePartitions() {
+    AssertionResult MapUpdateSnapshot(const std::string& name,
+                                      std::unique_ptr<ICowWriter>* writer) {
         TestPartitionOpener opener(fake_super);
-        auto builder = MetadataBuilder::NewForUpdate(opener, "super", 0, 1);
-        if (!builder) return false;
+        CreateLogicalPartitionParams params{
+                .block_device = fake_super,
+                .metadata_slot = 1,
+                .partition_name = name,
+                .timeout_ms = 10s,
+                .partition_opener = &opener,
+        };
 
-        auto metadata = builder->Export();
-        if (!metadata) return false;
-
-        // Update the destination slot, mark it as updated.
-        if (!UpdatePartitionTable(opener, "super", *metadata.get(), 1)) {
-            return false;
+        auto result = sm->OpenSnapshotWriter(params, {});
+        if (!result) {
+            return AssertionFailure() << "Cannot open snapshot for writing: " << name;
+        }
+        if (!result->Initialize()) {
+            return AssertionFailure() << "Cannot initialize snapshot for writing: " << name;
         }
 
-        for (const auto& partition : metadata->partitions) {
-            CreateLogicalPartitionParams params = {
-                    .block_device = fake_super,
-                    .metadata = metadata.get(),
-                    .partition = &partition,
-                    .force_writable = true,
-                    .timeout_ms = 10s,
-                    .device_name = GetPartitionName(partition) + "-base",
-            };
-            std::string ignore_path;
-            if (!CreateLogicalPartition(params, &ignore_path)) {
-                return false;
-            }
+        if (writer) {
+            *writer = std::move(result);
         }
-        return true;
+        return AssertionSuccess();
+    }
+
+    AssertionResult MapUpdateSnapshot(const std::string& name, std::string* path) {
+        TestPartitionOpener opener(fake_super);
+        CreateLogicalPartitionParams params{
+                .block_device = fake_super,
+                .metadata_slot = 1,
+                .partition_name = name,
+                .timeout_ms = 10s,
+                .partition_opener = &opener,
+        };
+
+        auto result = sm->MapUpdateSnapshot(params, path);
+        if (!result) {
+            return AssertionFailure() << "Cannot open snapshot for writing: " << name;
+        }
+        return AssertionSuccess();
     }
 
     AssertionResult DeleteSnapshotDevice(const std::string& snapshot) {
         AssertionResult res = AssertionSuccess();
         if (!(res = DeleteDevice(snapshot))) return res;
+        if (!sm->UnmapDmUserDevice(snapshot)) {
+            return AssertionFailure() << "Cannot delete dm-user device for " << snapshot;
+        }
         if (!(res = DeleteDevice(snapshot + "-inner"))) return res;
         if (!(res = DeleteDevice(snapshot + "-cow"))) return res;
         if (!image_manager_->UnmapImageIfExists(snapshot + "-cow-img")) {
@@ -289,37 +317,55 @@
 
     // Prepare A/B slot for a partition named "test_partition".
     AssertionResult PrepareOneSnapshot(uint64_t device_size,
-                                       std::string* out_snap_device = nullptr) {
-        std::string base_device, cow_device, snap_device;
-        if (!CreatePartition("test_partition_a", device_size)) {
-            return AssertionFailure();
+                                       std::unique_ptr<ICowWriter>* writer = nullptr) {
+        lock_ = nullptr;
+
+        DeltaArchiveManifest manifest;
+
+        auto group = manifest.mutable_dynamic_partition_metadata()->add_groups();
+        group->set_name("group");
+        group->set_size(device_size * 2);
+        group->add_partition_names("test_partition");
+
+        auto pu = manifest.add_partitions();
+        pu->set_partition_name("test_partition");
+        pu->set_estimate_cow_size(device_size);
+        SetSize(pu, device_size);
+
+        auto extent = pu->add_operations()->add_dst_extents();
+        extent->set_start_block(0);
+        if (device_size) {
+            extent->set_num_blocks(device_size / manifest.block_size());
         }
-        if (!MapUpdatePartitions()) {
-            return AssertionFailure();
+
+        TestPartitionOpener opener(fake_super);
+        auto builder = MetadataBuilder::New(opener, "super", 0);
+        if (!builder) {
+            return AssertionFailure() << "Failed to open MetadataBuilder";
         }
-        if (!dm_.GetDmDevicePathByName("test_partition_b-base", &base_device)) {
-            return AssertionFailure();
+        builder->AddGroup("group_a", 16_GiB);
+        builder->AddGroup("group_b", 16_GiB);
+        if (!CreatePartition(builder.get(), "test_partition_a", device_size, nullptr, "group_a")) {
+            return AssertionFailure() << "Failed create test_partition_a";
         }
-        SnapshotStatus status;
-        status.set_name("test_partition_b");
-        status.set_device_size(device_size);
-        status.set_snapshot_size(device_size);
-        status.set_cow_file_size(device_size);
-        if (!sm->CreateSnapshot(lock_.get(), &status)) {
-            return AssertionFailure();
+
+        if (!sm->CreateUpdateSnapshots(manifest)) {
+            return AssertionFailure() << "Failed to create update snapshots";
         }
-        if (!CreateCowImage("test_partition_b")) {
-            return AssertionFailure();
+
+        if (writer) {
+            auto res = MapUpdateSnapshot("test_partition_b", writer);
+            if (!res) {
+                return res;
+            }
+        } else if (!IsCompressionEnabled()) {
+            std::string ignore;
+            if (!MapUpdateSnapshot("test_partition_b", &ignore)) {
+                return AssertionFailure() << "Failed to map test_partition_b";
+            }
         }
-        if (!MapCowImage("test_partition_b", 10s, &cow_device)) {
-            return AssertionFailure();
-        }
-        if (!sm->MapSnapshot(lock_.get(), "test_partition_b", base_device, cow_device, 10s,
-                             &snap_device)) {
-            return AssertionFailure();
-        }
-        if (out_snap_device) {
-            *out_snap_device = std::move(snap_device);
+        if (!AcquireLock()) {
+            return AssertionFailure() << "Failed to acquire lock";
         }
         return AssertionSuccess();
     }
@@ -328,20 +374,37 @@
     AssertionResult SimulateReboot() {
         lock_ = nullptr;
         if (!sm->FinishedSnapshotWrites(false)) {
-            return AssertionFailure();
+            return AssertionFailure() << "Failed to finish snapshot writes";
         }
-        if (!dm_.DeleteDevice("test_partition_b")) {
-            return AssertionFailure();
+        if (!sm->UnmapUpdateSnapshot("test_partition_b")) {
+            return AssertionFailure() << "Failed to unmap COW for test_partition_b";
         }
-        if (!DestroyLogicalPartition("test_partition_b-base")) {
-            return AssertionFailure();
+        if (!dm_.DeleteDeviceIfExists("test_partition_b")) {
+            return AssertionFailure() << "Failed to delete test_partition_b";
         }
-        if (!sm->UnmapCowImage("test_partition_b")) {
-            return AssertionFailure();
+        if (!dm_.DeleteDeviceIfExists("test_partition_b-base")) {
+            return AssertionFailure() << "Failed to destroy test_partition_b-base";
         }
         return AssertionSuccess();
     }
 
+    std::unique_ptr<SnapshotManager> NewManagerForFirstStageMount(
+            const std::string& slot_suffix = "_a") {
+        auto info = new TestDeviceInfo(fake_super, slot_suffix);
+        return NewManagerForFirstStageMount(info);
+    }
+
+    std::unique_ptr<SnapshotManager> NewManagerForFirstStageMount(TestDeviceInfo* info) {
+        auto init = SnapshotManager::NewForFirstStageMount(info);
+        if (!init) {
+            return nullptr;
+        }
+        init->SetUeventRegenCallback([](const std::string& device) -> bool {
+            return android::fs_mgr::WaitForFile(device, snapshot_timeout_);
+        });
+        return init;
+    }
+
     static constexpr std::chrono::milliseconds snapshot_timeout_ = 5s;
     DeviceMapper& dm_;
     std::unique_ptr<SnapshotManager::LockedFile> lock_;
@@ -439,8 +502,7 @@
 TEST_F(SnapshotTest, CleanFirstStageMount) {
     // If there's no update in progress, there should be no first-stage mount
     // needed.
-    TestDeviceInfo* info = new TestDeviceInfo(fake_super);
-    auto sm = SnapshotManager::NewForFirstStageMount(info);
+    auto sm = NewManagerForFirstStageMount();
     ASSERT_NE(sm, nullptr);
     ASSERT_FALSE(sm->NeedSnapshotsInFirstStageMount());
 }
@@ -449,8 +511,7 @@
     ASSERT_TRUE(sm->FinishedSnapshotWrites(false));
 
     // We didn't change the slot, so we shouldn't need snapshots.
-    TestDeviceInfo* info = new TestDeviceInfo(fake_super);
-    auto sm = SnapshotManager::NewForFirstStageMount(info);
+    auto sm = NewManagerForFirstStageMount();
     ASSERT_NE(sm, nullptr);
     ASSERT_FALSE(sm->NeedSnapshotsInFirstStageMount());
 
@@ -462,32 +523,30 @@
     ASSERT_TRUE(AcquireLock());
 
     static const uint64_t kDeviceSize = 1024 * 1024;
-    std::string snap_device;
-    ASSERT_TRUE(PrepareOneSnapshot(kDeviceSize, &snap_device));
 
-    std::string test_string = "This is a test string.";
-    {
-        unique_fd fd(open(snap_device.c_str(), O_RDWR | O_CLOEXEC | O_SYNC));
-        ASSERT_GE(fd, 0);
-        ASSERT_TRUE(android::base::WriteFully(fd, test_string.data(), test_string.size()));
-    }
-
-    // Note: we know there is no inner/outer dm device since we didn't request
-    // a linear segment.
-    DeviceMapper::TargetInfo target;
-    ASSERT_TRUE(sm->IsSnapshotDevice("test_partition_b", &target));
-    ASSERT_EQ(DeviceMapper::GetTargetType(target.spec), "snapshot");
+    std::unique_ptr<ICowWriter> writer;
+    ASSERT_TRUE(PrepareOneSnapshot(kDeviceSize, &writer));
 
     // Release the lock.
     lock_ = nullptr;
 
+    std::string test_string = "This is a test string.";
+    test_string.resize(writer->options().block_size);
+    ASSERT_TRUE(writer->AddRawBlocks(0, test_string.data(), test_string.size()));
+    ASSERT_TRUE(writer->Finalize());
+    writer = nullptr;
+
     // Done updating.
     ASSERT_TRUE(sm->FinishedSnapshotWrites(false));
 
+    ASSERT_TRUE(sm->UnmapUpdateSnapshot("test_partition_b"));
+
     test_device->set_slot_suffix("_b");
+    ASSERT_TRUE(sm->CreateLogicalAndSnapshotPartitions("super", snapshot_timeout_));
     ASSERT_TRUE(sm->InitiateMerge());
 
     // The device should have been switched to a snapshot-merge target.
+    DeviceMapper::TargetInfo target;
     ASSERT_TRUE(sm->IsSnapshotDevice("test_partition_b", &target));
     ASSERT_EQ(DeviceMapper::GetTargetType(target.spec), "snapshot-merge");
 
@@ -503,7 +562,7 @@
     // Test that we can read back the string we wrote to the snapshot. Note
     // that the base device is gone now. |snap_device| contains the correct
     // partition.
-    unique_fd fd(open(snap_device.c_str(), O_RDONLY | O_CLOEXEC));
+    unique_fd fd(open("/dev/block/mapper/test_partition_b", O_RDONLY | O_CLOEXEC));
     ASSERT_GE(fd, 0);
 
     std::string buffer(test_string.size(), '\0');
@@ -518,7 +577,7 @@
     ASSERT_TRUE(PrepareOneSnapshot(kDeviceSize));
     ASSERT_TRUE(SimulateReboot());
 
-    auto init = SnapshotManager::NewForFirstStageMount(new TestDeviceInfo(fake_super, "_b"));
+    auto init = NewManagerForFirstStageMount("_b");
     ASSERT_NE(init, nullptr);
     ASSERT_TRUE(init->NeedSnapshotsInFirstStageMount());
     ASSERT_TRUE(init->CreateLogicalAndSnapshotPartitions("super", snapshot_timeout_));
@@ -547,7 +606,7 @@
     FormatFakeSuper();
     ASSERT_TRUE(CreatePartition("test_partition_b", kDeviceSize));
 
-    auto init = SnapshotManager::NewForFirstStageMount(new TestDeviceInfo(fake_super, "_b"));
+    auto init = NewManagerForFirstStageMount("_b");
     ASSERT_NE(init, nullptr);
     ASSERT_TRUE(init->NeedSnapshotsInFirstStageMount());
     ASSERT_TRUE(init->CreateLogicalAndSnapshotPartitions("super", snapshot_timeout_));
@@ -574,7 +633,7 @@
     ASSERT_TRUE(PrepareOneSnapshot(kDeviceSize));
     ASSERT_TRUE(SimulateReboot());
 
-    auto init = SnapshotManager::NewForFirstStageMount(new TestDeviceInfo(fake_super, "_b"));
+    auto init = NewManagerForFirstStageMount("_b");
     ASSERT_NE(init, nullptr);
     ASSERT_TRUE(init->NeedSnapshotsInFirstStageMount());
     ASSERT_TRUE(init->CreateLogicalAndSnapshotPartitions("super", snapshot_timeout_));
@@ -899,47 +958,7 @@
         return AssertionSuccess();
     }
 
-    AssertionResult MapUpdateSnapshot(const std::string& name,
-                                      std::unique_ptr<ICowWriter>* writer) {
-        CreateLogicalPartitionParams params{
-                .block_device = fake_super,
-                .metadata_slot = 1,
-                .partition_name = name,
-                .timeout_ms = 10s,
-                .partition_opener = opener_.get(),
-        };
-
-        auto result = sm->OpenSnapshotWriter(params, {});
-        if (!result) {
-            return AssertionFailure() << "Cannot open snapshot for writing: " << name;
-        }
-        if (!result->Initialize()) {
-            return AssertionFailure() << "Cannot initialize snapshot for writing: " << name;
-        }
-
-        if (writer) {
-            *writer = std::move(result);
-        }
-        return AssertionSuccess();
-    }
-
-    AssertionResult MapUpdateSnapshot(const std::string& name, std::string* path) {
-        CreateLogicalPartitionParams params{
-                .block_device = fake_super,
-                .metadata_slot = 1,
-                .partition_name = name,
-                .timeout_ms = 10s,
-                .partition_opener = opener_.get(),
-        };
-
-        auto result = sm->MapUpdateSnapshot(params, path);
-        if (!result) {
-            return AssertionFailure() << "Cannot open snapshot for writing: " << name;
-        }
-        return AssertionSuccess();
-    }
-
-    AssertionResult MapUpdateSnapshot(const std::string& name) {
+    AssertionResult MapOneUpdateSnapshot(const std::string& name) {
         if (IsCompressionEnabled()) {
             std::unique_ptr<ICowWriter> writer;
             return MapUpdateSnapshot(name, &writer);
@@ -983,7 +1002,7 @@
     AssertionResult MapUpdateSnapshots(const std::vector<std::string>& names = {"sys_b", "vnd_b",
                                                                                 "prd_b"}) {
         for (const auto& name : names) {
-            auto res = MapUpdateSnapshot(name);
+            auto res = MapOneUpdateSnapshot(name);
             if (!res) {
                 return res;
             }
@@ -1070,7 +1089,7 @@
     ASSERT_TRUE(UnmapAll());
 
     // After reboot, init does first stage mount.
-    auto init = SnapshotManager::NewForFirstStageMount(new TestDeviceInfo(fake_super, "_b"));
+    auto init = NewManagerForFirstStageMount("_b");
     ASSERT_NE(init, nullptr);
     ASSERT_TRUE(init->NeedSnapshotsInFirstStageMount());
     ASSERT_TRUE(init->CreateLogicalAndSnapshotPartitions("super", snapshot_timeout_));
@@ -1202,7 +1221,7 @@
     ASSERT_TRUE(UnmapAll());
 
     // After reboot, init does first stage mount.
-    auto init = SnapshotManager::NewForFirstStageMount(new TestDeviceInfo(fake_super, "_b"));
+    auto init = NewManagerForFirstStageMount("_b");
     ASSERT_NE(init, nullptr);
     ASSERT_TRUE(init->NeedSnapshotsInFirstStageMount());
     ASSERT_TRUE(init->CreateLogicalAndSnapshotPartitions("super", snapshot_timeout_));
@@ -1214,7 +1233,7 @@
 
     // Simulate shutting down the device again.
     ASSERT_TRUE(UnmapAll());
-    init = SnapshotManager::NewForFirstStageMount(new TestDeviceInfo(fake_super, "_a"));
+    init = NewManagerForFirstStageMount("_a");
     ASSERT_NE(init, nullptr);
     ASSERT_FALSE(init->NeedSnapshotsInFirstStageMount());
     ASSERT_TRUE(init->CreateLogicalAndSnapshotPartitions("super", snapshot_timeout_));
@@ -1251,7 +1270,7 @@
     ASSERT_TRUE(UnmapAll());
 
     // After reboot, init does first stage mount.
-    auto init = SnapshotManager::NewForFirstStageMount(new TestDeviceInfo(fake_super, "_b"));
+    auto init = NewManagerForFirstStageMount("_b");
     ASSERT_NE(init, nullptr);
     ASSERT_TRUE(init->NeedSnapshotsInFirstStageMount());
     ASSERT_TRUE(init->CreateLogicalAndSnapshotPartitions("super", snapshot_timeout_));
@@ -1387,8 +1406,8 @@
     ASSERT_TRUE(UnmapAll());
 
     // After reboot, init does first stage mount.
-    // Normally we should use NewForFirstStageMount, but if so, "gsid.mapped_image.sys_b-cow-img"
-    // won't be set.
+    // Normally we should use NewManagerForFirstStageMount, but if so,
+    // "gsid.mapped_image.sys_b-cow-img" won't be set.
     auto init = SnapshotManager::New(new TestDeviceInfo(fake_super, "_b"));
     ASSERT_NE(init, nullptr);
     ASSERT_TRUE(init->CreateLogicalAndSnapshotPartitions("super", snapshot_timeout_));
@@ -1495,7 +1514,7 @@
     ASSERT_TRUE(UnmapAll());
 
     // After reboot, init does first stage mount.
-    auto init = SnapshotManager::NewForFirstStageMount(new TestDeviceInfo(fake_super, "_b"));
+    auto init = NewManagerForFirstStageMount("_b");
     ASSERT_NE(init, nullptr);
     ASSERT_TRUE(init->NeedSnapshotsInFirstStageMount());
     ASSERT_TRUE(init->CreateLogicalAndSnapshotPartitions("super", snapshot_timeout_));
@@ -1509,7 +1528,7 @@
     // Simulate a reboot into recovery.
     auto test_device = std::make_unique<TestDeviceInfo>(fake_super, "_b");
     test_device->set_recovery(true);
-    new_sm = SnapshotManager::NewForFirstStageMount(test_device.release());
+    new_sm = NewManagerForFirstStageMount(test_device.release());
 
     ASSERT_TRUE(new_sm->HandleImminentDataWipe());
     ASSERT_EQ(new_sm->GetUpdateState(), UpdateState::None);
@@ -1527,7 +1546,7 @@
     ASSERT_TRUE(UnmapAll());
 
     // After reboot, init does first stage mount.
-    auto init = SnapshotManager::NewForFirstStageMount(new TestDeviceInfo(fake_super, "_b"));
+    auto init = NewManagerForFirstStageMount("_b");
     ASSERT_NE(init, nullptr);
     ASSERT_TRUE(init->NeedSnapshotsInFirstStageMount());
     ASSERT_TRUE(init->CreateLogicalAndSnapshotPartitions("super", snapshot_timeout_));
@@ -1541,7 +1560,7 @@
     // Simulate a reboot into recovery.
     auto test_device = std::make_unique<TestDeviceInfo>(fake_super, "_b");
     test_device->set_recovery(true);
-    new_sm = SnapshotManager::NewForFirstStageMount(test_device.release());
+    new_sm = NewManagerForFirstStageMount(test_device.release());
 
     ASSERT_TRUE(new_sm->FinishMergeInRecovery());
 
@@ -1551,12 +1570,12 @@
 
     // Finish the merge in a normal boot.
     test_device = std::make_unique<TestDeviceInfo>(fake_super, "_b");
-    init = SnapshotManager::NewForFirstStageMount(test_device.release());
+    init = NewManagerForFirstStageMount(test_device.release());
     ASSERT_TRUE(init->CreateLogicalAndSnapshotPartitions("super", snapshot_timeout_));
     init = nullptr;
 
     test_device = std::make_unique<TestDeviceInfo>(fake_super, "_b");
-    new_sm = SnapshotManager::NewForFirstStageMount(test_device.release());
+    new_sm = NewManagerForFirstStageMount(test_device.release());
     ASSERT_EQ(new_sm->ProcessUpdateState(), UpdateState::MergeCompleted);
     ASSERT_EQ(new_sm->ProcessUpdateState(), UpdateState::None);
 }
@@ -1575,7 +1594,7 @@
     // Simulate a reboot into recovery.
     auto test_device = new TestDeviceInfo(fake_super, "_b");
     test_device->set_recovery(true);
-    auto new_sm = SnapshotManager::NewForFirstStageMount(test_device);
+    auto new_sm = NewManagerForFirstStageMount(test_device);
 
     ASSERT_TRUE(new_sm->HandleImminentDataWipe());
     // Manually mount metadata so that we can call GetUpdateState() below.
@@ -1600,7 +1619,7 @@
     // Simulate a rollback, with reboot into recovery.
     auto test_device = new TestDeviceInfo(fake_super, "_a");
     test_device->set_recovery(true);
-    auto new_sm = SnapshotManager::NewForFirstStageMount(test_device);
+    auto new_sm = NewManagerForFirstStageMount(test_device);
 
     ASSERT_TRUE(new_sm->HandleImminentDataWipe());
     EXPECT_EQ(new_sm->GetUpdateState(), UpdateState::None);
@@ -1628,7 +1647,7 @@
     // Simulate a reboot into recovery.
     auto test_device = new TestDeviceInfo(fake_super, "_b");
     test_device->set_recovery(true);
-    auto new_sm = SnapshotManager::NewForFirstStageMount(test_device);
+    auto new_sm = NewManagerForFirstStageMount(test_device);
 
     ASSERT_TRUE(new_sm->HandleImminentDataWipe());
     // Manually mount metadata so that we can call GetUpdateState() below.
@@ -1639,7 +1658,7 @@
 
     // Now reboot into new slot.
     test_device = new TestDeviceInfo(fake_super, "_b");
-    auto init = SnapshotManager::NewForFirstStageMount(test_device);
+    auto init = NewManagerForFirstStageMount(test_device);
     ASSERT_TRUE(init->CreateLogicalAndSnapshotPartitions("super", snapshot_timeout_));
     // Verify that we are on the downgraded build.
     for (const auto& name : {"sys_b", "vnd_b", "prd_b"}) {
@@ -1685,7 +1704,7 @@
     ASSERT_TRUE(UnmapAll());
 
     // After reboot, init does first stage mount.
-    auto init = SnapshotManager::NewForFirstStageMount(new TestDeviceInfo(fake_super, "_b"));
+    auto init = NewManagerForFirstStageMount("_b");
     ASSERT_NE(init, nullptr);
     ASSERT_TRUE(init->NeedSnapshotsInFirstStageMount());
     ASSERT_TRUE(init->CreateLogicalAndSnapshotPartitions("super", snapshot_timeout_));
@@ -1773,7 +1792,7 @@
     ASSERT_TRUE(sm->FinishedSnapshotWrites(false));
     ASSERT_TRUE(UnmapAll());
 
-    auto init = SnapshotManager::NewForFirstStageMount(new TestDeviceInfo(fake_super, "_b"));
+    auto init = NewManagerForFirstStageMount("_b");
     ASSERT_NE(init, nullptr);
 
     ASSERT_TRUE(init->EnsureSnapuserdConnected());
@@ -1785,7 +1804,7 @@
     ASSERT_EQ(access("/dev/dm-user/sys_b-user-cow-init", F_OK), 0);
     ASSERT_EQ(access("/dev/dm-user/sys_b-user-cow", F_OK), -1);
 
-    ASSERT_TRUE(init->PerformSecondStageTransition());
+    ASSERT_TRUE(init->PerformInitTransition(SnapshotManager::InitTransition::SECOND_STAGE));
 
     // The control device should have been renamed.
     ASSERT_TRUE(android::fs_mgr::WaitForFileDeleted("/dev/dm-user/sys_b-user-cow-init", 10s));
@@ -1887,8 +1906,7 @@
     ASSERT_TRUE(UnmapAll());
 
     // Simulate reboot. After reboot, init does first stage mount.
-    auto init = SnapshotManager::NewForFirstStageMount(
-            new TestDeviceInfo(fake_super, flashed_slot_suffix));
+    auto init = NewManagerForFirstStageMount(flashed_slot_suffix);
     ASSERT_NE(init, nullptr);
 
     if (flashed_slot && after_merge) {
diff --git a/init/Android.bp b/init/Android.bp
index 19ba21b..cd295cf 100644
--- a/init/Android.bp
+++ b/init/Android.bp
@@ -60,6 +60,7 @@
     "selabel.cpp",
     "selinux.cpp",
     "sigchld_handler.cpp",
+    "snapuserd_transition.cpp",
     "switch_root.cpp",
     "uevent_listener.cpp",
     "ueventd.cpp",
diff --git a/init/Android.mk b/init/Android.mk
index c881e2f..561d641 100644
--- a/init/Android.mk
+++ b/init/Android.mk
@@ -57,6 +57,8 @@
     reboot_utils.cpp \
     selabel.cpp \
     selinux.cpp \
+    service_utils.cpp \
+    snapuserd_transition.cpp \
     switch_root.cpp \
     uevent_listener.cpp \
     util.cpp \
diff --git a/init/block_dev_initializer.h b/init/block_dev_initializer.h
index 79fe4ec..ec39ce0 100644
--- a/init/block_dev_initializer.h
+++ b/init/block_dev_initializer.h
@@ -12,6 +12,8 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
+#pragma once
+
 #include <memory>
 #include <set>
 #include <string>
diff --git a/init/first_stage_init.cpp b/init/first_stage_init.cpp
index 0949fc5..6954c03 100644
--- a/init/first_stage_init.cpp
+++ b/init/first_stage_init.cpp
@@ -42,6 +42,7 @@
 #include "first_stage_mount.h"
 #include "reboot_utils.h"
 #include "second_stage_resources.h"
+#include "snapuserd_transition.h"
 #include "switch_root.h"
 #include "util.h"
 
@@ -90,6 +91,12 @@
                     }
                 }
             }
+        } else if (de->d_type == DT_REG) {
+            // Do not free snapuserd if we will need the ramdisk copy during the
+            // selinux transition.
+            if (de->d_name == "snapuserd"s && IsFirstStageSnapuserdRunning()) {
+                continue;
+            }
         }
         unlinkat(dfd, de->d_name, is_dir ? AT_REMOVEDIR : 0);
     }
diff --git a/init/first_stage_mount.cpp b/init/first_stage_mount.cpp
index a0511cc..7c46281 100644
--- a/init/first_stage_mount.cpp
+++ b/init/first_stage_mount.cpp
@@ -44,6 +44,7 @@
 
 #include "block_dev_initializer.h"
 #include "devices.h"
+#include "snapuserd_transition.h"
 #include "switch_root.h"
 #include "uevent.h"
 #include "uevent_listener.h"
@@ -87,6 +88,7 @@
   protected:
     bool InitRequiredDevices(std::set<std::string> devices);
     bool CreateLogicalPartitions();
+    bool CreateSnapshotPartitions(android::snapshot::SnapshotManager* sm);
     bool MountPartition(const Fstab::iterator& begin, bool erase_same_mounts,
                         Fstab::iterator* end = nullptr);
 
@@ -109,6 +111,7 @@
 
     bool need_dm_verity_;
     bool dsu_not_on_userdata_ = false;
+    bool use_snapuserd_ = false;
 
     Fstab fstab_;
     // The super path is only set after InitDevices, and is invalid before.
@@ -338,21 +341,7 @@
             return false;
         }
         if (sm->NeedSnapshotsInFirstStageMount()) {
-            // When COW images are present for snapshots, they are stored on
-            // the data partition.
-            if (!InitRequiredDevices({"userdata"})) {
-                return false;
-            }
-            sm->SetUeventRegenCallback([this](const std::string& device) -> bool {
-                if (android::base::StartsWith(device, "/dev/block/dm-")) {
-                    return block_dev_init_.InitDmDevice(device);
-                }
-                if (android::base::StartsWith(device, "/dev/dm-user/")) {
-                    return block_dev_init_.InitDmUser(android::base::Basename(device));
-                }
-                return block_dev_init_.InitDevices({device});
-            });
-            return sm->CreateLogicalAndSnapshotPartitions(super_path_);
+            return CreateSnapshotPartitions(sm.get());
         }
     }
 
@@ -367,6 +356,37 @@
     return android::fs_mgr::CreateLogicalPartitions(*metadata.get(), super_path_);
 }
 
+bool FirstStageMount::CreateSnapshotPartitions(SnapshotManager* sm) {
+    // When COW images are present for snapshots, they are stored on
+    // the data partition.
+    if (!InitRequiredDevices({"userdata"})) {
+        return false;
+    }
+
+    use_snapuserd_ = sm->IsSnapuserdRequired();
+    if (use_snapuserd_) {
+        LaunchFirstStageSnapuserd();
+    }
+
+    sm->SetUeventRegenCallback([this](const std::string& device) -> bool {
+        if (android::base::StartsWith(device, "/dev/block/dm-")) {
+            return block_dev_init_.InitDmDevice(device);
+        }
+        if (android::base::StartsWith(device, "/dev/dm-user/")) {
+            return block_dev_init_.InitDmUser(android::base::Basename(device));
+        }
+        return block_dev_init_.InitDevices({device});
+    });
+    if (!sm->CreateLogicalAndSnapshotPartitions(super_path_)) {
+        return false;
+    }
+
+    if (use_snapuserd_) {
+        CleanupSnapuserdSocket();
+    }
+    return true;
+}
+
 bool FirstStageMount::MountPartition(const Fstab::iterator& begin, bool erase_same_mounts,
                                      Fstab::iterator* end) {
     // Sets end to begin + 1, so we can just return on failure below.
@@ -466,6 +486,10 @@
 
     if (system_partition == fstab_.end()) return true;
 
+    if (use_snapuserd_) {
+        SaveRamdiskPathToSnapuserd();
+    }
+
     if (MountPartition(system_partition, false /* erase_same_mounts */)) {
         if (dsu_not_on_userdata_ && fs_mgr_verity_is_check_at_most_once(*system_partition)) {
             LOG(ERROR) << "check_most_at_once forbidden on external media";
diff --git a/init/init.cpp b/init/init.cpp
index d1998a5..d6753f5 100644
--- a/init/init.cpp
+++ b/init/init.cpp
@@ -79,6 +79,7 @@
 #include "service.h"
 #include "service_parser.h"
 #include "sigchld_handler.h"
+#include "snapuserd_transition.h"
 #include "subcontext.h"
 #include "system/core/init/property_service.pb.h"
 #include "util.h"
@@ -741,10 +742,15 @@
         return {};
     }
     svc->Start();
+    svc->SetShutdownCritical();
 
-    if (!sm->PerformSecondStageTransition()) {
+    if (!sm->PerformSecondStageInitTransition()) {
         LOG(FATAL) << "Failed to transition snapuserd to second-stage";
     }
+
+    if (auto pid = GetSnapuserdFirstStagePid()) {
+        KillFirstStageSnapuserd(pid.value());
+    }
     return {};
 }
 
diff --git a/init/selinux.cpp b/init/selinux.cpp
index f03ca6b..0336936 100644
--- a/init/selinux.cpp
+++ b/init/selinux.cpp
@@ -45,7 +45,7 @@
 // 2) If these hashes do not match, then either /system or /system_ext or /product (or some of them)
 //    have been updated out of sync with /vendor (or /odm if it is present) and the init needs to
 //    compile the SEPolicy.  /system contains the SEPolicy compiler, secilc, and it is used by the
-//    LoadSplitPolicy() function below to compile the SEPolicy to a temp directory and load it.
+//    OpenSplitPolicy() function below to compile the SEPolicy to a temp directory and open it.
 //    That function contains even more documentation with the specific implementation details of how
 //    the SEPolicy is compiled if needed.
 
@@ -74,6 +74,7 @@
 #include "block_dev_initializer.h"
 #include "debug_ramdisk.h"
 #include "reboot_utils.h"
+#include "snapuserd_transition.h"
 #include "util.h"
 
 using namespace std::string_literals;
@@ -298,7 +299,12 @@
     return access(plat_policy_cil_file, R_OK) != -1;
 }
 
-bool LoadSplitPolicy() {
+struct PolicyFile {
+    unique_fd fd;
+    std::string path;
+};
+
+bool OpenSplitPolicy(PolicyFile* policy_file) {
     // IMPLEMENTATION NOTE: Split policy consists of three CIL files:
     // * platform -- policy needed due to logic contained in the system image,
     // * non-platform -- policy needed due to logic contained in the vendor image,
@@ -325,10 +331,8 @@
     if (!use_userdebug_policy && FindPrecompiledSplitPolicy(&precompiled_sepolicy_file)) {
         unique_fd fd(open(precompiled_sepolicy_file.c_str(), O_RDONLY | O_CLOEXEC | O_BINARY));
         if (fd != -1) {
-            if (selinux_android_load_policy_from_fd(fd, precompiled_sepolicy_file.c_str()) < 0) {
-                LOG(ERROR) << "Failed to load SELinux policy from " << precompiled_sepolicy_file;
-                return false;
-            }
+            policy_file->fd = std::move(fd);
+            policy_file->path = std::move(precompiled_sepolicy_file);
             return true;
         }
     }
@@ -446,34 +450,39 @@
     }
     unlink(compiled_sepolicy);
 
-    LOG(INFO) << "Loading compiled SELinux policy";
-    if (selinux_android_load_policy_from_fd(compiled_sepolicy_fd, compiled_sepolicy) < 0) {
-        LOG(ERROR) << "Failed to load SELinux policy from " << compiled_sepolicy;
-        return false;
-    }
-
+    policy_file->fd = std::move(compiled_sepolicy_fd);
+    policy_file->path = compiled_sepolicy;
     return true;
 }
 
-bool LoadMonolithicPolicy() {
-    LOG(VERBOSE) << "Loading SELinux policy from monolithic file";
-    if (selinux_android_load_policy() < 0) {
-        PLOG(ERROR) << "Failed to load monolithic SELinux policy";
+bool OpenMonolithicPolicy(PolicyFile* policy_file) {
+    static constexpr char kSepolicyFile[] = "/sepolicy";
+
+    LOG(VERBOSE) << "Opening SELinux policy from monolithic file";
+    policy_file->fd.reset(open(kSepolicyFile, O_RDONLY | O_CLOEXEC | O_NOFOLLOW));
+    if (policy_file->fd < 0) {
+        PLOG(ERROR) << "Failed to open monolithic SELinux policy";
         return false;
     }
+    policy_file->path = kSepolicyFile;
     return true;
 }
 
-bool LoadPolicy() {
-    return IsSplitPolicyDevice() ? LoadSplitPolicy() : LoadMonolithicPolicy();
-}
+void ReadPolicy(std::string* policy) {
+    PolicyFile policy_file;
 
-void SelinuxInitialize() {
-    LOG(INFO) << "Loading SELinux policy";
-    if (!LoadPolicy()) {
-        LOG(FATAL) << "Unable to load SELinux policy";
+    bool ok = IsSplitPolicyDevice() ? OpenSplitPolicy(&policy_file)
+                                    : OpenMonolithicPolicy(&policy_file);
+    if (!ok) {
+        LOG(FATAL) << "Unable to open SELinux policy";
     }
 
+    if (!android::base::ReadFdToString(policy_file.fd, policy)) {
+        PLOG(FATAL) << "Failed to read policy file: " << policy_file.path;
+    }
+}
+
+void SelinuxSetEnforcement() {
     bool kernel_enforcing = (security_getenforce() == 1);
     bool is_enforcing = IsEnforcing();
     if (kernel_enforcing != is_enforcing) {
@@ -670,6 +679,30 @@
     }
 }
 
+static void LoadSelinuxPolicy(std::string& policy) {
+    LOG(INFO) << "Loading SELinux policy";
+
+    set_selinuxmnt("/sys/fs/selinux");
+    if (security_load_policy(policy.data(), policy.size()) < 0) {
+        PLOG(FATAL) << "SELinux:  Could not load policy";
+    }
+}
+
+// The SELinux setup process is carefully orchestrated around snapuserd. Policy
+// must be loaded off dynamic partitions, and during an OTA, those partitions
+// cannot be read without snapuserd. But, with kernel-privileged snapuserd
+// running, loading the policy will immediately trigger audits.
+//
+// We use a five-step process to address this:
+//  (1) Read the policy into a string, with snapuserd running.
+//  (2) Rewrite the snapshot device-mapper tables, to generate new dm-user
+//      devices and to flush I/O.
+//  (3) Kill snapuserd, which no longer has any dm-user devices to attach to.
+//  (4) Load the sepolicy and issue critical restorecons in /dev, carefully
+//      avoiding anything that would read from /system.
+//  (5) Re-launch snapuserd and attach it to the dm-user devices from step (2).
+//
+// After this sequence, it is safe to enable enforcing mode and continue booting.
 int SetupSelinux(char** argv) {
     SetStdioToDevNull(argv);
     InitKernelLogging(argv);
@@ -682,9 +715,31 @@
 
     MountMissingSystemPartitions();
 
-    // Set up SELinux, loading the SELinux policy.
     SelinuxSetupKernelLogging();
-    SelinuxInitialize();
+
+    LOG(INFO) << "Opening SELinux policy";
+
+    // Read the policy before potentially killing snapuserd.
+    std::string policy;
+    ReadPolicy(&policy);
+
+    auto snapuserd_helper = SnapuserdSelinuxHelper::CreateIfNeeded();
+    if (snapuserd_helper) {
+        // Kill the old snapuserd to avoid audit messages. After this we cannot
+        // read from /system (or other dynamic partitions) until we call
+        // FinishTransition().
+        snapuserd_helper->StartTransition();
+    }
+
+    LoadSelinuxPolicy(policy);
+
+    if (snapuserd_helper) {
+        // Before enforcing, finish the pending snapuserd transition.
+        snapuserd_helper->FinishTransition();
+        snapuserd_helper = nullptr;
+    }
+
+    SelinuxSetEnforcement();
 
     // We're in the kernel domain and want to transition to the init domain.  File systems that
     // store SELabels in their xattrs, such as ext4 do not need an explicit restorecon here,
diff --git a/init/snapuserd_transition.cpp b/init/snapuserd_transition.cpp
new file mode 100644
index 0000000..19b5c57
--- /dev/null
+++ b/init/snapuserd_transition.cpp
@@ -0,0 +1,305 @@
+/*
+ * Copyright (C) 2020 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "snapuserd_transition.h"
+
+#include <sys/mman.h>
+#include <sys/socket.h>
+#include <sys/syscall.h>
+#include <sys/xattr.h>
+#include <unistd.h>
+
+#include <filesystem>
+#include <string>
+
+#include <android-base/file.h>
+#include <android-base/logging.h>
+#include <android-base/parseint.h>
+#include <android-base/strings.h>
+#include <android-base/unique_fd.h>
+#include <cutils/sockets.h>
+#include <libsnapshot/snapshot.h>
+#include <libsnapshot/snapuserd_client.h>
+#include <private/android_filesystem_config.h>
+#include <selinux/android.h>
+
+#include "block_dev_initializer.h"
+#include "service_utils.h"
+#include "util.h"
+
+namespace android {
+namespace init {
+
+using namespace std::string_literals;
+
+using android::base::unique_fd;
+using android::snapshot::SnapshotManager;
+using android::snapshot::SnapuserdClient;
+
+static constexpr char kSnapuserdPath[] = "/system/bin/snapuserd";
+static constexpr char kSnapuserdFirstStagePidVar[] = "FIRST_STAGE_SNAPUSERD_PID";
+static constexpr char kSnapuserdFirstStageFdVar[] = "FIRST_STAGE_SNAPUSERD_FD";
+static constexpr char kSnapuserdLabel[] = "u:object_r:snapuserd_exec:s0";
+static constexpr char kSnapuserdSocketLabel[] = "u:object_r:snapuserd_socket:s0";
+
+void LaunchFirstStageSnapuserd() {
+    SocketDescriptor socket_desc;
+    socket_desc.name = android::snapshot::kSnapuserdSocket;
+    socket_desc.type = SOCK_STREAM;
+    socket_desc.perm = 0660;
+    socket_desc.uid = AID_SYSTEM;
+    socket_desc.gid = AID_SYSTEM;
+
+    // We specify a label here even though it technically is not needed. During
+    // first_stage_mount there is no sepolicy loaded. Once sepolicy is loaded,
+    // we bypass the socket entirely.
+    auto socket = socket_desc.Create(kSnapuserdSocketLabel);
+    if (!socket.ok()) {
+        LOG(FATAL) << "Could not create snapuserd socket: " << socket.error();
+    }
+
+    pid_t pid = fork();
+    if (pid < 0) {
+        PLOG(FATAL) << "Cannot launch snapuserd; fork failed";
+    }
+    if (pid == 0) {  // Child: publish the socket, then exec snapuserd.
+        socket->Publish();
+        char arg0[] = "/system/bin/snapuserd";
+        char* const argv[] = {arg0, nullptr};
+        if (execv(arg0, argv) < 0) {
+            PLOG(FATAL) << "Cannot launch snapuserd; execv failed";
+        }
+        _exit(127);
+    }
+
+    setenv(kSnapuserdFirstStagePidVar, std::to_string(pid).c_str(), 1);
+
+    LOG(INFO) << "Launched first-stage snapuserd with pid: " << pid;
+}
+
+std::optional<pid_t> GetSnapuserdFirstStagePid() {
+    const char* pid_str = getenv(kSnapuserdFirstStagePidVar);
+    if (!pid_str) {
+        return {};
+    }
+
+    int pid = 0;
+    if (!android::base::ParseInt(pid_str, &pid)) {
+        LOG(FATAL) << "Could not parse pid in environment, " << kSnapuserdFirstStagePidVar << "="
+                   << pid_str;
+    }
+    return {pid};
+}
+
+static void RelabelLink(const std::string& link) {
+    selinux_android_restorecon(link.c_str(), 0);
+
+    std::string path;
+    if (android::base::Readlink(link, &path)) {
+        selinux_android_restorecon(path.c_str(), 0);
+    }
+}
+
+static void RelabelDeviceMapper() {
+    selinux_android_restorecon("/dev/device-mapper", 0);
+
+    std::error_code ec;
+    for (auto& iter : std::filesystem::directory_iterator("/dev/block", ec)) {
+        const auto& path = iter.path();
+        if (android::base::StartsWith(path.string(), "/dev/block/dm-")) {
+            selinux_android_restorecon(path.string().c_str(), 0);
+        }
+    }
+}
+
+static std::optional<int> GetRamdiskSnapuserdFd() {
+    const char* fd_str = getenv(kSnapuserdFirstStageFdVar);
+    if (!fd_str) {
+        return {};
+    }
+
+    int fd;
+    if (!android::base::ParseInt(fd_str, &fd)) {
+        LOG(FATAL) << "Could not parse fd in environment, " << kSnapuserdFirstStageFdVar << "="
+                   << fd_str;
+    }
+    return {fd};
+}
+
+void RestoreconRamdiskSnapuserd(int fd) {
+    if (fsetxattr(fd, XATTR_NAME_SELINUX, kSnapuserdLabel, strlen(kSnapuserdLabel) + 1, 0) < 0) {
+        PLOG(FATAL) << "fsetxattr snapuserd failed";
+    }
+}
+
+SnapuserdSelinuxHelper::SnapuserdSelinuxHelper(std::unique_ptr<SnapshotManager>&& sm, pid_t old_pid)
+    : sm_(std::move(sm)), old_pid_(old_pid) {
+    // Only dm-user device names change during transitions, so the other
+    // devices are expected to be present.
+    sm_->SetUeventRegenCallback([this](const std::string& device) -> bool {
+        if (android::base::StartsWith(device, "/dev/dm-user/")) {
+            return block_dev_init_.InitDmUser(android::base::Basename(device));
+        }
+        return true;
+    });
+}
+
+void SnapuserdSelinuxHelper::StartTransition() {
+    LOG(INFO) << "Starting SELinux transition of snapuserd";
+
+    // The restorecon path reads from /system etc, so make sure any reads have
+    // been cached before proceeding.
+    auto handle = selinux_android_file_context_handle();
+    if (!handle) {
+        LOG(FATAL) << "Could not create SELinux file context handle";
+    }
+    selinux_android_set_sehandle(handle);
+
+    // We cannot access /system after the transition, so make sure init is
+    // pinned in memory.
+    if (mlockall(MCL_CURRENT) < 0) {
+        PLOG(FATAL) << "mlockall failed";
+    }
+
+    argv_.emplace_back("snapuserd");
+    argv_.emplace_back("-no_socket");
+    if (!sm_->DetachSnapuserdForSelinux(&argv_)) {
+        LOG(FATAL) << "Could not perform selinux transition";
+    }
+
+    // Signal the old daemon to exit (SIGTERM, no wait) so we don't have any selinux audits.
+    KillFirstStageSnapuserd(old_pid_);
+}
+
+void SnapuserdSelinuxHelper::FinishTransition() {
+    RelabelLink("/dev/block/by-name/super");
+    RelabelDeviceMapper();
+
+    selinux_android_restorecon("/dev/null", 0);
+    selinux_android_restorecon("/dev/urandom", 0);
+    selinux_android_restorecon("/dev/kmsg", 0);
+    selinux_android_restorecon("/dev/dm-user", SELINUX_ANDROID_RESTORECON_RECURSE);
+
+    RelaunchFirstStageSnapuserd();
+
+    if (munlockall() < 0) {
+        PLOG(ERROR) << "munlockall failed";
+    }
+}
+
+void SnapuserdSelinuxHelper::RelaunchFirstStageSnapuserd() {
+    auto fd = GetRamdiskSnapuserdFd();
+    if (!fd) {
+        LOG(FATAL) << "Environment variable " << kSnapuserdFirstStageFdVar << " was not set!";
+    }
+    unsetenv(kSnapuserdFirstStageFdVar);
+
+    RestoreconRamdiskSnapuserd(fd.value());
+
+    pid_t pid = fork();
+    if (pid < 0) {
+        PLOG(FATAL) << "Fork to relaunch snapuserd failed";
+    }
+    if (pid > 0) {
+        // We don't need the descriptor anymore, and it should be closed to
+        // avoid leaking into subprocesses.
+        close(fd.value());
+
+        setenv(kSnapuserdFirstStagePidVar, std::to_string(pid).c_str(), 1);
+
+        LOG(INFO) << "Relaunched snapuserd with pid: " << pid;
+        return;
+    }
+
+    // Make sure the descriptor is gone after we exec.
+    if (fcntl(fd.value(), F_SETFD, FD_CLOEXEC) < 0) {
+        PLOG(FATAL) << "fcntl FD_CLOEXEC failed for snapuserd fd";
+    }
+
+    std::vector<char*> argv;
+    for (auto& arg : argv_) {
+        argv.emplace_back(arg.data());
+    }
+    argv.emplace_back(nullptr);
+
+    int rv = syscall(SYS_execveat, fd.value(), "", reinterpret_cast<char* const*>(argv.data()),
+                     nullptr, AT_EMPTY_PATH);
+    if (rv < 0) {
+        PLOG(FATAL) << "Failed to execveat() snapuserd";
+    }
+}
+
+std::unique_ptr<SnapuserdSelinuxHelper> SnapuserdSelinuxHelper::CreateIfNeeded() {
+    if (IsRecoveryMode()) {
+        return nullptr;
+    }
+
+    auto old_pid = GetSnapuserdFirstStagePid();
+    if (!old_pid) {
+        return nullptr;
+    }
+
+    auto sm = SnapshotManager::NewForFirstStageMount();
+    if (!sm) {
+        LOG(FATAL) << "Unable to create SnapshotManager";
+    }
+    return std::make_unique<SnapuserdSelinuxHelper>(std::move(sm), old_pid.value());
+}
+
+void KillFirstStageSnapuserd(pid_t pid) {
+    if (kill(pid, SIGTERM) < 0 && errno != ESRCH) {  // ESRCH is not reported as an error.
+        PLOG(ERROR) << "Kill snapuserd pid failed: " << pid;
+    } else {
+        LOG(INFO) << "Sent SIGTERM to snapuserd process " << pid;
+    }
+}
+
+void CleanupSnapuserdSocket() {
+    auto socket_path = ANDROID_SOCKET_DIR "/"s + android::snapshot::kSnapuserdSocket;
+    if (access(socket_path.c_str(), F_OK) != 0) {
+        return;
+    }
+
+    // Tell the daemon to stop accepting connections and to gracefully exit
+    // once all outstanding handlers have terminated.
+    if (auto client = SnapuserdClient::Connect(android::snapshot::kSnapuserdSocket, 3s)) {
+        client->DetachSnapuserd();
+    }
+
+    // Unlink the socket so we can create it again in second-stage.
+    if (unlink(socket_path.c_str()) < 0) {
+        PLOG(FATAL) << "unlink " << socket_path << " failed";
+    }
+}
+
+void SaveRamdiskPathToSnapuserd() {
+    int fd = open(kSnapuserdPath, O_PATH);
+    if (fd < 0) {
+        PLOG(FATAL) << "Unable to open snapuserd: " << kSnapuserdPath;
+    }
+
+    auto value = std::to_string(fd);
+    if (setenv(kSnapuserdFirstStageFdVar, value.c_str(), 1) < 0) {
+        PLOG(FATAL) << "setenv failed: " << kSnapuserdFirstStageFdVar << "=" << value;
+    }
+}
+
+bool IsFirstStageSnapuserdRunning() {
+    return GetSnapuserdFirstStagePid().has_value();
+}
+
+}  // namespace init
+}  // namespace android
diff --git a/init/snapuserd_transition.h b/init/snapuserd_transition.h
new file mode 100644
index 0000000..a5ab652
--- /dev/null
+++ b/init/snapuserd_transition.h
@@ -0,0 +1,87 @@
+/*
+ * Copyright (C) 2020 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include <sys/types.h>
+
+#include <optional>
+#include <string>
+#include <vector>
+
+#include <libsnapshot/snapshot.h>
+
+#include "block_dev_initializer.h"
+
+namespace android {
+namespace init {
+
+// Fork and exec a new copy of snapuserd.
+void LaunchFirstStageSnapuserd();
+
+class SnapuserdSelinuxHelper final {
+    using SnapshotManager = android::snapshot::SnapshotManager;
+
+  public:
+    SnapuserdSelinuxHelper(std::unique_ptr<SnapshotManager>&& sm, pid_t old_pid);
+
+    void StartTransition();
+    void FinishTransition();
+
+    // Return a helper for facilitating the selinux transition of snapuserd.
+    // If snapuserd is not in use, null is returned. StartTransition() should
+    // be called after reading policy. FinishTransition() should be called
+    // after loading policy. In between, no reads of /system or other dynamic
+    // partitions are possible.
+    static std::unique_ptr<SnapuserdSelinuxHelper> CreateIfNeeded();
+
+  private:
+    void RelaunchFirstStageSnapuserd();
+    void ExecSnapuserd();
+
+    std::unique_ptr<SnapshotManager> sm_;
+    BlockDevInitializer block_dev_init_;
+    pid_t old_pid_;
+    std::vector<std::string> argv_;
+};
+
+// Remove /dev/socket/snapuserd. This ensures that (1) the existing snapuserd
+// will receive no new requests, and (2) the next copy we transition to can
+// own the socket.
+void CleanupSnapuserdSocket();
+
+// Kill an instance of snapuserd given a pid.
+void KillFirstStageSnapuserd(pid_t pid);
+
+// Save an open fd to /system/bin/snapuserd (in the ramdisk) into an environment
+// variable. This is used to later execveat() snapuserd.
+void SaveRamdiskPathToSnapuserd();
+
+// Returns true if first-stage snapuserd is running.
+bool IsFirstStageSnapuserdRunning();
+
+// Return the pid of the first-stage instance of snapuserd, if it was started.
+std::optional<pid_t> GetSnapuserdFirstStagePid();
+
+// Save an open fd to /system/bin/snapuserd (in the ramdisk) into an environment
+// variable. This is used to later execveat() snapuserd.
+void SaveRamdiskPathToSnapuserd();
+
+// Returns true if first-stage snapuserd is running.
+bool IsFirstStageSnapuserdRunning();
+
+}  // namespace init
+}  // namespace android