init/epoll: Make Epoll::Wait() easier to use

Invoke the callback functions from inside Epoll::Wait() instead of
returning a vector with pointers to callback functions. Remove handlers
after handler invocation finished to prevent that self-removal triggers
a use-after-free.

The CL that made Epoll::Wait() return a vector is available at
https://android-review.googlesource.com/c/platform/system/core/+/1112042.

Bug: 213617178
Change-Id: I52c6ade5746a911510746f83802684f2d9cfb429
Signed-off-by: Bart Van Assche <bvanassche@google.com>
diff --git a/init/epoll.cpp b/init/epoll.cpp
index 3a830ce..f814c65 100644
--- a/init/epoll.cpp
+++ b/init/epoll.cpp
@@ -69,9 +69,11 @@
     if (epoll_ctl(epoll_fd_, EPOLL_CTL_DEL, fd, nullptr) == -1) {
         return ErrnoError() << "epoll_ctl failed to remove fd";
     }
-    if (epoll_handlers_.erase(fd) != 1) {
+    auto it = epoll_handlers_.find(fd);
+    if (it == epoll_handlers_.end()) {
         return Error() << "Attempting to remove epoll handler for FD without an existing handler";
     }
+    to_remove_.insert(it->first);
     return {};
 }
 
@@ -79,8 +81,7 @@
     first_callback_ = std::move(first_callback);
 }
 
-Result<std::vector<std::shared_ptr<Epoll::Handler>>> Epoll::Wait(
-        std::optional<std::chrono::milliseconds> timeout) {
+Result<int> Epoll::Wait(std::optional<std::chrono::milliseconds> timeout) {
     int timeout_ms = -1;
     if (timeout && timeout->count() < INT_MAX) {
         timeout_ms = timeout->count();
@@ -94,7 +95,6 @@
     if (num_events > 0 && first_callback_) {
         first_callback_();
     }
-    std::vector<std::shared_ptr<Handler>> pending_functions;
     for (int i = 0; i < num_events; ++i) {
         const auto it = epoll_handlers_.find(ev[i].data.fd);
         if (it == epoll_handlers_.end()) {
@@ -107,10 +107,13 @@
             // Log something informational.
             LOG(ERROR) << "Received unexpected epoll event set: " << ev[i].events;
         }
-        pending_functions.emplace_back(info.handler);
+        (*info.handler)();
+        for (auto fd : to_remove_) {
+            epoll_handlers_.erase(fd);
+        }
+        to_remove_.clear();
     }
-
-    return pending_functions;
+    return num_events;
 }
 
 }  // namespace init
diff --git a/init/epoll.h b/init/epoll.h
index e26e319..5853134 100644
--- a/init/epoll.h
+++ b/init/epoll.h
@@ -24,6 +24,7 @@
 #include <map>
 #include <memory>
 #include <optional>
+#include <unordered_set>
 #include <vector>
 
 #include <android-base/unique_fd.h>
@@ -43,8 +44,7 @@
     Result<void> RegisterHandler(int fd, Handler handler, uint32_t events = EPOLLIN);
     Result<void> UnregisterHandler(int fd);
     void SetFirstCallback(std::function<void()> first_callback);
-    Result<std::vector<std::shared_ptr<Handler>>> Wait(
-            std::optional<std::chrono::milliseconds> timeout);
+    Result<int> Wait(std::optional<std::chrono::milliseconds> timeout);
 
   private:
     struct Info {
@@ -55,6 +55,7 @@
     android::base::unique_fd epoll_fd_;
     std::map<int, Info> epoll_handlers_;
     std::function<void()> first_callback_;
+    std::unordered_set<int> to_remove_;
 };
 
 }  // namespace init
diff --git a/init/epoll_test.cpp b/init/epoll_test.cpp
index 3f8b5a4..7105a68 100644
--- a/init/epoll_test.cpp
+++ b/init/epoll_test.cpp
@@ -60,14 +60,9 @@
     uint8_t byte = 0xee;
     ASSERT_TRUE(android::base::WriteFully(fds[1], &byte, sizeof(byte)));
 
-    auto results = epoll.Wait({});
-    ASSERT_RESULT_OK(results);
-    ASSERT_EQ(results->size(), size_t(1));
-
-    for (const auto& function : *results) {
-        (*function)();
-        (*function)();
-    }
+    auto epoll_result = epoll.Wait({});
+    ASSERT_RESULT_OK(epoll_result);
+    ASSERT_EQ(*epoll_result, 1);
     ASSERT_TRUE(handler_invoked);
 }
 
diff --git a/init/init.cpp b/init/init.cpp
index f9e7c6e..837d955 100644
--- a/init/init.cpp
+++ b/init/init.cpp
@@ -1177,14 +1177,10 @@
             if (am.HasMoreCommands()) epoll_timeout = 0ms;
         }
 
-        auto pending_functions = epoll.Wait(epoll_timeout);
-        if (!pending_functions.ok()) {
-            LOG(ERROR) << pending_functions.error();
-        } else if (!pending_functions->empty()) {
-            for (const auto& function : *pending_functions) {
-                (*function)();
-            }
-        } else if (Service::is_exec_service_running()) {
+        auto epoll_result = epoll.Wait(epoll_timeout);
+        if (!epoll_result.ok()) {
+            LOG(ERROR) << epoll_result.error();
+        } else if (*epoll_result <= 0 && Service::is_exec_service_running()) {
             static bool dumped_diagnostics = false;
             std::chrono::duration<double> waited =
                     std::chrono::steady_clock::now() - Service::exec_service_started();
diff --git a/init/keychords_test.cpp b/init/keychords_test.cpp
index 8a333a2..5789bf5 100644
--- a/init/keychords_test.cpp
+++ b/init/keychords_test.cpp
@@ -212,11 +212,8 @@
 }
 
 void TestFrame::RelaxForMs(std::chrono::milliseconds wait) {
-    auto pending_functions = epoll_.Wait(wait);
-    ASSERT_RESULT_OK(pending_functions);
-    for (const auto& function : *pending_functions) {
-        (*function)();
-    }
+    auto epoll_result = epoll_.Wait(wait);
+    ASSERT_RESULT_OK(epoll_result);
 }
 
 void TestFrame::SetChord(int key, bool value) {
diff --git a/init/property_service.cpp b/init/property_service.cpp
index c2ba8d5..f3550a1 100644
--- a/init/property_service.cpp
+++ b/init/property_service.cpp
@@ -1381,13 +1381,9 @@
     }
 
     while (true) {
-        auto pending_functions = epoll.Wait(std::nullopt);
-        if (!pending_functions.ok()) {
-            LOG(ERROR) << pending_functions.error();
-        } else {
-            for (const auto& function : *pending_functions) {
-                (*function)();
-            }
+        auto epoll_result = epoll.Wait(std::nullopt);
+        if (!epoll_result.ok()) {
+            LOG(ERROR) << epoll_result.error();
         }
     }
 }