Stop services when unload an apex

Bug: 238854102
Test: atest CtsInitTestCases ApexTestCases
Change-Id: I3b9df9424f7841c42bd1bde27cd0e0750615bd6c
diff --git a/init/Android.bp b/init/Android.bp
index 2dd9683..58f4a9e 100644
--- a/init/Android.bp
+++ b/init/Android.bp
@@ -217,6 +217,7 @@
         "selinux_policy_version",
     ],
     srcs: init_common_sources + init_device_sources,
+    export_include_dirs: ["."],
     generated_sources: [
         "apex-info-list",
     ],
@@ -246,6 +247,10 @@
             ],
         },
     },
+    visibility: [
+        "//system/apex/apexd",
+        "//frameworks/native/cmds/installd",
+    ],
 }
 
 phony {
diff --git a/init/init.cpp b/init/init.cpp
index 5f516b7..9192214 100644
--- a/init/init.cpp
+++ b/init/init.cpp
@@ -79,6 +79,7 @@
 #include "selabel.h"
 #include "selinux.h"
 #include "service.h"
+#include "service_list.h"
 #include "service_parser.h"
 #include "sigchld_handler.h"
 #include "snapuserd_transition.h"
@@ -443,11 +444,32 @@
     return {};
 }
 
+int StopServicesFromApex(const std::string& apex_name) {
+    auto services = ServiceList::GetInstance().FindServicesByApexName(apex_name);
+    if (services.empty()) {
+        LOG(INFO) << "No service found for APEX: " << apex_name;
+        return 0;
+    }
+    std::set<std::string> service_names;
+    for (const auto& service : services) {
+        service_names.emplace(service->name());
+    }
+    constexpr std::chrono::milliseconds kServiceStopTimeout = 10s;
+    int still_running = StopServicesAndLogViolations(service_names, kServiceStopTimeout,
+                        true /*SIGTERM*/);
+    // Send SIGKILL to ones that didn't terminate cleanly.
+    if (still_running > 0) {
+        still_running = StopServicesAndLogViolations(service_names, 0ms, false /*SIGKILL*/);
+    }
+    return still_running;
+}
+
 static Result<void> DoUnloadApex(const std::string& apex_name) {
-    std::string prop_name = "init.apex." + apex_name;
+    if (StopServicesFromApex(apex_name) > 0) {
+        return Error() << "Unable to stop all service from " << apex_name;
+    }
     // TODO(b/232114573) remove services and actions read from the apex
-    // TODO(b/232799709) kill services from the apex
-    SetProperty(prop_name, "unloaded");
+    SetProperty("init.apex." + apex_name, "unloaded");
     return {};
 }
 
@@ -471,14 +493,12 @@
 }
 
 static Result<void> DoLoadApex(const std::string& apex_name) {
-    std::string prop_name = "init.apex." + apex_name;
     // TODO(b/232799709) read .rc files from the apex
-
     if (auto result = UpdateApexLinkerConfig(apex_name); !result.ok()) {
         return result.error();
     }
 
-    SetProperty(prop_name, "loaded");
+    SetProperty("init.apex." + apex_name, "loaded");
     return {};
 }
 
diff --git a/init/init.h b/init/init.h
index 5220535..dd44e95 100644
--- a/init/init.h
+++ b/init/init.h
@@ -46,5 +46,7 @@
 
 int SecondStageMain(int argc, char** argv);
 
+int StopServicesFromApex(const std::string& apex_name);
+
 }  // namespace init
 }  // namespace android
diff --git a/init/init_test.cpp b/init/init_test.cpp
index 5651a83..e7218e8 100644
--- a/init/init_test.cpp
+++ b/init/init_test.cpp
@@ -15,11 +15,14 @@
  */
 
 #include <functional>
+#include <string_view>
+#include <type_traits>
 
 #include <android-base/file.h>
 #include <android-base/logging.h>
 #include <android-base/properties.h>
 #include <gtest/gtest.h>
+#include <selinux/selinux.h>
 
 #include "action.h"
 #include "action_manager.h"
@@ -27,6 +30,7 @@
 #include "builtin_arguments.h"
 #include "builtins.h"
 #include "import_parser.h"
+#include "init.h"
 #include "keyword_map.h"
 #include "parser.h"
 #include "service.h"
@@ -37,6 +41,7 @@
 using android::base::GetIntProperty;
 using android::base::GetProperty;
 using android::base::SetProperty;
+using android::base::StringReplace;
 using android::base::WaitForProperty;
 using namespace std::literals;
 
@@ -188,6 +193,186 @@
     EXPECT_TRUE(service->is_override());
 }
 
+static std::string GetSecurityContext() {
+    char* ctx;
+    if (getcon(&ctx) == -1) {
+        ADD_FAILURE() << "Failed to call getcon : " << strerror(errno);
+    }
+    std::string result = std::string(ctx);
+    freecon(ctx);
+    return result;
+}
+
+void TestStartApexServices(const std::vector<std::string>& service_names,
+        const std::string& apex_name) {
+    for (auto const& svc : service_names) {
+        auto service = ServiceList::GetInstance().FindService(svc);
+        ASSERT_NE(nullptr, service);
+        ASSERT_RESULT_OK(service->Start());
+        ASSERT_TRUE(service->IsRunning());
+        LOG(INFO) << "Service " << svc << " is running";
+        if (!apex_name.empty()) {
+            service->set_filename("/apex/" + apex_name + "/init_test.rc");
+        } else {
+            service->set_filename("");
+        }
+    }
+    if (!apex_name.empty()) {
+        auto apex_services = ServiceList::GetInstance().FindServicesByApexName(apex_name);
+        EXPECT_EQ(service_names.size(), apex_services.size());
+    }
+}
+
+void TestStopApexServices(const std::vector<std::string>& service_names, bool expect_to_run) {
+    for (auto const& svc : service_names) {
+        auto service = ServiceList::GetInstance().FindService(svc);
+        ASSERT_NE(nullptr, service);
+        EXPECT_EQ(expect_to_run, service->IsRunning());
+    }
+    ServiceList::GetInstance().RemoveServiceIf([&](const std::unique_ptr<Service>& s) -> bool {
+        if (std::find(service_names.begin(), service_names.end(), s->name())
+                != service_names.end()) {
+            return true;
+        }
+        return false;
+    });
+}
+
+void InitApexService(const std::string_view& init_template) {
+    std::string init_script = StringReplace(init_template, "$selabel",
+                                    GetSecurityContext(), true);
+
+    ActionManager action_manager;
+    TestInitText(init_script, BuiltinFunctionMap(), {}, &action_manager,
+            &ServiceList::GetInstance());
+}
+
+void TestApexServicesInit(const std::vector<std::string>& apex_services,
+            const std::vector<std::string>& other_apex_services,
+            const std::vector<std::string> non_apex_services) {
+    auto num_svc = apex_services.size() + other_apex_services.size() + non_apex_services.size();
+    ASSERT_EQ(static_cast<long>(num_svc), std::distance(ServiceList::GetInstance().begin(),
+            ServiceList::GetInstance().end()));
+
+    TestStartApexServices(apex_services, "com.android.apex.test_service");
+    TestStartApexServices(other_apex_services, "com.android.other_apex.test_service");
+    TestStartApexServices(non_apex_services, /*apex_anme=*/ "");
+
+    StopServicesFromApex("com.android.apex.test_service");
+    TestStopApexServices(apex_services, /*expect_to_run=*/ false);
+    TestStopApexServices(other_apex_services, /*expect_to_run=*/ true);
+    TestStopApexServices(non_apex_services, /*expect_to_run=*/ true);
+
+    ASSERT_EQ(0, std::distance(ServiceList::GetInstance().begin(),
+       ServiceList::GetInstance().end()));
+}
+
+TEST(init, StopServiceByApexName) {
+    std::string_view script_template = R"init(
+service apex_test_service /system/bin/yes
+    user shell
+    group shell
+    seclabel $selabel
+)init";
+    InitApexService(script_template);
+    TestApexServicesInit({"apex_test_service"}, {}, {});
+}
+
+TEST(init, StopMultipleServicesByApexName) {
+    std::string_view script_template = R"init(
+service apex_test_service_multiple_a /system/bin/yes
+    user shell
+    group shell
+    seclabel $selabel
+service apex_test_service_multiple_b /system/bin/id
+    user shell
+    group shell
+    seclabel $selabel
+)init";
+    InitApexService(script_template);
+    TestApexServicesInit({"apex_test_service_multiple_a",
+            "apex_test_service_multiple_b"}, {}, {});
+}
+
+TEST(init, StopServicesFromMultipleApexes) {
+    std::string_view apex_script_template = R"init(
+service apex_test_service_multi_apex_a /system/bin/yes
+    user shell
+    group shell
+    seclabel $selabel
+service apex_test_service_multi_apex_b /system/bin/id
+    user shell
+    group shell
+    seclabel $selabel
+)init";
+    InitApexService(apex_script_template);
+
+    std::string_view other_apex_script_template = R"init(
+service apex_test_service_multi_apex_c /system/bin/yes
+    user shell
+    group shell
+    seclabel $selabel
+)init";
+    InitApexService(other_apex_script_template);
+
+    TestApexServicesInit({"apex_test_service_multi_apex_a",
+            "apex_test_service_multi_apex_b"}, {"apex_test_service_multi_apex_c"}, {});
+}
+
+TEST(init, StopServicesFromApexAndNonApex) {
+    std::string_view apex_script_template = R"init(
+service apex_test_service_apex_a /system/bin/yes
+    user shell
+    group shell
+    seclabel $selabel
+service apex_test_service_apex_b /system/bin/id
+    user shell
+    group shell
+    seclabel $selabel
+)init";
+    InitApexService(apex_script_template);
+
+    std::string_view non_apex_script_template = R"init(
+service apex_test_service_non_apex /system/bin/yes
+    user shell
+    group shell
+    seclabel $selabel
+)init";
+    InitApexService(non_apex_script_template);
+
+    TestApexServicesInit({"apex_test_service_apex_a",
+            "apex_test_service_apex_b"}, {}, {"apex_test_service_non_apex"});
+}
+
+TEST(init, StopServicesFromApexMixed) {
+    std::string_view script_template = R"init(
+service apex_test_service_mixed_a /system/bin/yes
+    user shell
+    group shell
+    seclabel $selabel
+)init";
+    InitApexService(script_template);
+
+    std::string_view other_apex_script_template = R"init(
+service apex_test_service_mixed_b /system/bin/yes
+    user shell
+    group shell
+    seclabel $selabel
+)init";
+    InitApexService(other_apex_script_template);
+
+    std::string_view non_apex_script_template = R"init(
+service apex_test_service_mixed_c /system/bin/yes
+    user shell
+    group shell
+    seclabel $selabel
+)init";
+    InitApexService(non_apex_script_template);
+
+    TestApexServicesInit({"apex_test_service_mixed_a"},
+            {"apex_test_service_mixed_b"}, {"apex_test_service_mixed_c"});
+}
+
 TEST(init, EventTriggerOrderMultipleFiles) {
     // 6 total files, which should have their triggers executed in the following order:
     // 1: start - original script parsed
diff --git a/init/service.h b/init/service.h
index f7f32d9..7af3615 100644
--- a/init/service.h
+++ b/init/service.h
@@ -142,6 +142,8 @@
         }
     }
     Subcontext* subcontext() const { return subcontext_; }
+    const std::string& filename() const { return filename_; }
+    void set_filename(const std::string& name) { filename_ = name; }
 
   private:
     void NotifyStateChange(const std::string& new_state) const;
diff --git a/init/service_list.h b/init/service_list.h
index 555da25..33aaa5f 100644
--- a/init/service_list.h
+++ b/init/service_list.h
@@ -16,10 +16,14 @@
 
 #pragma once
 
+#include <iterator>
 #include <memory>
 #include <vector>
 
+#include <android-base/logging.h>
+
 #include "service.h"
+#include "util.h"
 
 namespace android {
 namespace init {
@@ -52,6 +56,17 @@
         return nullptr;
     }
 
+    std::vector<Service*> FindServicesByApexName(const std::string& apex_name) const {
+        CHECK(!apex_name.empty()) << "APEX name cannot be empty";
+        std::vector<Service*> matches;
+        for (const auto& svc : services_) {
+            if (GetApexNameFromFileName(svc->filename()) == apex_name) {
+                matches.emplace_back(svc.get());
+            }
+        }
+        return matches;
+    }
+
     Service* FindInterface(const std::string& interface_name) {
         for (const auto& svc : services_) {
             if (svc->interfaces().count(interface_name) > 0) {