Parallel Module Loading: Refactor threads pool and architecture

Instead of creating new threads in each run, reuse same threads to load the modules and we can keep the utilization of threads and cpufreq higher.

And let subthreads dedicated to loading modules in parallel or alone, main thread only needs to update independent module pool. this architecture should be simpler and more readable.

Bug: 402263126
Change-Id: If3f07a18cd7c43132ff799d32786d2466f3110f1
Signed-off-by: Chungkai Mei <chungkai@google.com>
diff --git a/init/first_stage_init.cpp b/init/first_stage_init.cpp
index 6bb0ad7..49af5ed 100644
--- a/init/first_stage_init.cpp
+++ b/init/first_stage_init.cpp
@@ -211,8 +211,8 @@
 }
 
 #define MODULE_BASE_DIR "/lib/modules"
-bool LoadKernelModules(BootMode boot_mode, bool want_console, bool want_parallel,
-                       int& modules_loaded) {
+bool LoadKernelModules(BootMode boot_mode, bool want_console,
+                       Modprobe::LoadParallelMode want_parallel_mode, int& modules_loaded) {
     struct utsname uts {};
     if (uname(&uts)) {
         LOG(FATAL) << "Failed to get kernel version.";
@@ -279,8 +279,9 @@
     }
 
     Modprobe m({MODULE_BASE_DIR}, GetModuleLoadList(boot_mode, MODULE_BASE_DIR));
-    bool retval = (want_parallel) ? m.LoadModulesParallel(std::thread::hardware_concurrency())
-                                  : m.LoadListedModules(!want_console);
+    bool retval = (want_parallel_mode != Modprobe::LoadParallelMode::NONE) ?
+            m.LoadModulesParallel(std::thread::hardware_concurrency(), want_parallel_mode) :
+            m.LoadListedModules(!want_console);
     modules_loaded = m.GetModuleCount();
     if (modules_loaded > 0) {
         LOG(INFO) << "Loaded " << modules_loaded << " modules from " << MODULE_BASE_DIR;
@@ -437,14 +438,15 @@
     }
 
     auto want_console = ALLOW_FIRST_STAGE_CONSOLE ? FirstStageConsole(cmdline, bootconfig) : 0;
-    auto want_parallel =
-            bootconfig.find("androidboot.load_modules_parallel = \"true\"") != std::string::npos;
+    auto want_parallel_mode = Modprobe::LoadParallelMode::NONE;
+    if (bootconfig.find("androidboot.load_modules_parallel = \"true\"")
+        != std::string::npos)
+        want_parallel_mode = Modprobe::LoadParallelMode::NORMAL;
 
     boot_clock::time_point module_start_time = boot_clock::now();
     int module_count = 0;
     BootMode boot_mode = GetBootMode(cmdline, bootconfig);
-    if (!LoadKernelModules(boot_mode, want_console,
-                           want_parallel, module_count)) {
+    if (!LoadKernelModules(boot_mode, want_console, want_parallel_mode, module_count)) {
         if (want_console != FirstStageConsoleParam::DISABLED) {
             LOG(ERROR) << "Failed to load kernel modules, starting console";
         } else {
diff --git a/libmodprobe/include/modprobe/modprobe.h b/libmodprobe/include/modprobe/modprobe.h
index d33e17d..de9dcd2 100644
--- a/libmodprobe/include/modprobe/modprobe.h
+++ b/libmodprobe/include/modprobe/modprobe.h
@@ -28,10 +28,15 @@
 
 class Modprobe {
   public:
+    enum LoadParallelMode {
+      NONE = 0,
+      NORMAL,
+    };
+
     Modprobe(const std::vector<std::string>&, const std::string load_file = "modules.load",
              bool use_blocklist = true);
 
-    bool LoadModulesParallel(int num_threads);
+    bool LoadModulesParallel(int num_threads, int mode);
     bool LoadListedModules(bool strict = true);
     bool LoadWithAliases(const std::string& module_name, bool strict,
                          const std::string& parameters = "");
@@ -45,6 +50,7 @@
     bool IsBlocklisted(const std::string& module_name);
 
   private:
+    bool IsLoadSequential(const std::string& module);
     std::string MakeCanonical(const std::string& module_path);
     bool InsmodWithDeps(const std::string& module_name, const std::string& parameters);
     bool Insmod(const std::string& path_name, const std::string& parameters);
diff --git a/libmodprobe/libmodprobe.cpp b/libmodprobe/libmodprobe.cpp
index bdd114c..1b524e9 100644
--- a/libmodprobe/libmodprobe.cpp
+++ b/libmodprobe/libmodprobe.cpp
@@ -24,6 +24,7 @@
 #include <sys/wait.h>
 
 #include <algorithm>
+#include <condition_variable>
 #include <map>
 #include <set>
 #include <string>
@@ -506,9 +507,18 @@
 // repeat these steps until all modules are loaded.
 // Discard all blocklist.
 // Softdeps are taken care in InsmodWithDeps().
-bool Modprobe::LoadModulesParallel(int num_threads) {
-    bool ret = true;
-    std::map<std::string, std::vector<std::string>> mod_with_deps;
+bool Modprobe::LoadModulesParallel(int num_threads, int mode) {
+    std::map<std::string, std::vector<std::string>> mods_with_deps;
+    std::unordered_set<std::string> mods_loading;
+    std::vector<std::string> parallel_modules, sequential_modules;
+    std::vector<std::thread> threads;
+    std::atomic<bool> ret(true);
+    std::atomic<bool> finish(false);
+    std::mutex mods_to_load_lock;
+    std::condition_variable cv_update_module, cv_load_module;
+    int sleeping_threads = 0;
+
+    LOG(INFO) << "LoadParallelMode:" << mode;
 
     // Get dependencies
     for (const auto& module : module_load_) {
@@ -517,22 +527,59 @@
             LOG(VERBOSE) << "LMP: Blocklist: Module " << module << " skipping...";
             continue;
         }
+
         auto dependencies = GetDependencies(MakeCanonical(module));
         if (dependencies.empty()) {
             LOG(ERROR) << "LMP: Hard-dep: Module " << module
                        << " not in .dep file";
             return false;
         }
-        mod_with_deps[MakeCanonical(module)] = dependencies;
+
+        mods_with_deps[MakeCanonical(module)] = dependencies;
     }
 
-    while (!mod_with_deps.empty()) {
-        std::vector<std::thread> threads;
-        std::vector<std::string> mods_path_to_load;
-        std::mutex vector_lock;
+    // Consumers load modules in parallel or sequentially
+    auto thread_function = [&] {
+        while (!mods_with_deps.empty() && ret.load()) {
+            std::unique_lock<std::mutex> lock(mods_to_load_lock);
 
-        // Find independent modules
-        for (const auto& [it_mod, it_dep] : mod_with_deps) {
+            if (sequential_modules.empty() && parallel_modules.empty()) {
+                sleeping_threads++;
+
+                if (mode == LoadParallelMode::NORMAL && sleeping_threads == num_threads)
+                    cv_update_module.notify_one();
+
+                cv_load_module.wait(lock, [&](){
+                    return !parallel_modules.empty() ||
+                           !sequential_modules.empty() ||
+                           finish.load(); });
+
+                sleeping_threads--;
+            }
+
+            while (!sequential_modules.empty()) {
+                auto mod_to_load = std::move(sequential_modules.back());
+                sequential_modules.pop_back();
+                ret.store(ret.load() && LoadWithAliases(mod_to_load, true));
+            }
+
+            if (!parallel_modules.empty()) {
+                auto mod_to_load = std::move(parallel_modules.back());
+                parallel_modules.pop_back();
+
+                lock.unlock();
+                ret.store(ret.load() && LoadWithAliases(mod_to_load, true));
+            }
+        }
+    };
+
+    std::generate_n(std::back_inserter(threads), num_threads,
+        [&] { return std::thread(thread_function); });
+
+    // Producer check there's any independent module
+    while (!mods_with_deps.empty()) {
+        std::unique_lock<std::mutex> lock(mods_to_load_lock);
+        for (const auto& [it_mod, it_dep] : mods_with_deps) {
             auto itd_last = it_dep.rbegin();
             if (itd_last == it_dep.rend())
                 continue;
@@ -542,69 +589,51 @@
             if (IsBlocklisted(cnd_last)) {
                 LOG(ERROR) << "LMP: Blocklist: Module-dep " << cnd_last
                            << " : failed to load module " << it_mod;
-                return false;
+                ret.store(0);
+                break;
             }
 
-            std::string str = "load_sequential=1";
-            auto it = module_options_[cnd_last].find(str);
-            if (it != std::string::npos) {
-                module_options_[cnd_last].erase(it, it + str.size());
+            if (mods_loading.find(cnd_last) == mods_loading.end()) {
+                mods_loading.insert(cnd_last);
 
-                if (!LoadWithAliases(cnd_last, true)) {
-                    return false;
-                }
-            } else {
-                if (std::find(mods_path_to_load.begin(), mods_path_to_load.end(),
-                            cnd_last) == mods_path_to_load.end()) {
-                    mods_path_to_load.emplace_back(cnd_last);
-                }
+                if (IsLoadSequential(cnd_last))
+                    sequential_modules.emplace_back(cnd_last);
+                else
+                    parallel_modules.emplace_back(cnd_last);
             }
         }
 
-        // Load independent modules in parallel
-        auto thread_function = [&] {
-            std::unique_lock lk(vector_lock);
-            while (!mods_path_to_load.empty()) {
-                auto ret_load = true;
-                auto mod_to_load = std::move(mods_path_to_load.back());
-                mods_path_to_load.pop_back();
+        cv_load_module.notify_all();
+        cv_update_module.wait(lock, [&](){
+            return parallel_modules.empty() &&
+                   sequential_modules.empty(); });
 
-                lk.unlock();
-                ret_load &= LoadWithAliases(mod_to_load, true);
-                lk.lock();
-                if (!ret_load) {
-                    ret &= ret_load;
-                }
-            }
-        };
-
-        std::generate_n(std::back_inserter(threads), num_threads,
-                        [&] { return std::thread(thread_function); });
-
-        // Wait for the threads.
-        for (auto& thread : threads) {
-            thread.join();
-        }
-
-        if (!ret) return ret;
+        if (!ret.load())
+            break;
 
         std::lock_guard guard(module_loaded_lock_);
-        // Remove loaded module form mod_with_deps and soft dependencies of other modules
+        // Remove loaded module from mods_with_deps
         for (const auto& module_loaded : module_loaded_)
-            mod_with_deps.erase(module_loaded);
+            mods_with_deps.erase(module_loaded);
 
-        // Remove loaded module form dependencies of other modules which are not loaded yet
+        // Remove loaded module from dependency list
         for (const auto& module_loaded_path : module_loaded_paths_) {
-            for (auto& [mod, deps] : mod_with_deps) {
+            for (auto& [mod, deps] : mods_with_deps) {
                 auto it = std::find(deps.begin(), deps.end(), module_loaded_path);
-                if (it != deps.end()) {
+                if (it != deps.end())
                     deps.erase(it);
-                }
             }
         }
     }
 
-    return ret;
+    finish.store(true);
+    cv_load_module.notify_all();
+
+    for (auto& thread : threads) {
+        thread.join();
+    }
+
+    return ret.load();
 }
 
 bool Modprobe::LoadListedModules(bool strict) {
@@ -641,6 +670,19 @@
     return rv;
 }
 
+bool Modprobe::IsLoadSequential(const std::string& module)
+{
+    std::string str = "load_sequential=1";
+    auto it = module_options_[module].find(str);
+
+    if (it != std::string::npos) {
+        module_options_[module].erase(it, it + str.size());
+        return true;
+    }
+
+    return false;
+}
+
 bool Modprobe::GetAllDependencies(const std::string& module,
                                   std::vector<std::string>* pre_dependencies,
                                   std::vector<std::string>* dependencies,