init/epoll: Make Epoll::Wait() easier to use
Invoke the callback functions from inside Epoll::Wait() instead of
returning a vector with pointers to callback functions. Remove handlers
after handler invocation finished to prevent that self-removal triggers
a use-after-free.
The CL that made Epoll::Wait() return a vector is available at
https://android-review.googlesource.com/c/platform/system/core/+/1112042.
Bug: 213617178
Change-Id: I52c6ade5746a911510746f83802684f2d9cfb429
Signed-off-by: Bart Van Assche <bvanassche@google.com>
diff --git a/init/epoll.cpp b/init/epoll.cpp
index 3a830ce..f814c65 100644
--- a/init/epoll.cpp
+++ b/init/epoll.cpp
@@ -69,9 +69,11 @@
if (epoll_ctl(epoll_fd_, EPOLL_CTL_DEL, fd, nullptr) == -1) {
return ErrnoError() << "epoll_ctl failed to remove fd";
}
- if (epoll_handlers_.erase(fd) != 1) {
+ auto it = epoll_handlers_.find(fd);
+ if (it == epoll_handlers_.end()) {
return Error() << "Attempting to remove epoll handler for FD without an existing handler";
}
+ to_remove_.insert(it->first);
return {};
}
@@ -79,8 +81,7 @@
first_callback_ = std::move(first_callback);
}
-Result<std::vector<std::shared_ptr<Epoll::Handler>>> Epoll::Wait(
- std::optional<std::chrono::milliseconds> timeout) {
+Result<int> Epoll::Wait(std::optional<std::chrono::milliseconds> timeout) {
int timeout_ms = -1;
if (timeout && timeout->count() < INT_MAX) {
timeout_ms = timeout->count();
@@ -94,7 +95,6 @@
if (num_events > 0 && first_callback_) {
first_callback_();
}
- std::vector<std::shared_ptr<Handler>> pending_functions;
for (int i = 0; i < num_events; ++i) {
const auto it = epoll_handlers_.find(ev[i].data.fd);
if (it == epoll_handlers_.end()) {
@@ -107,10 +107,13 @@
// Log something informational.
LOG(ERROR) << "Received unexpected epoll event set: " << ev[i].events;
}
- pending_functions.emplace_back(info.handler);
+ (*info.handler)();
+ for (auto fd : to_remove_) {
+ epoll_handlers_.erase(fd);
+ }
+ to_remove_.clear();
}
-
- return pending_functions;
+ return num_events;
}
} // namespace init
diff --git a/init/epoll.h b/init/epoll.h
index e26e319..5853134 100644
--- a/init/epoll.h
+++ b/init/epoll.h
@@ -24,6 +24,7 @@
#include <map>
#include <memory>
#include <optional>
+#include <unordered_set>
#include <vector>
#include <android-base/unique_fd.h>
@@ -43,8 +44,7 @@
Result<void> RegisterHandler(int fd, Handler handler, uint32_t events = EPOLLIN);
Result<void> UnregisterHandler(int fd);
void SetFirstCallback(std::function<void()> first_callback);
- Result<std::vector<std::shared_ptr<Handler>>> Wait(
- std::optional<std::chrono::milliseconds> timeout);
+ Result<int> Wait(std::optional<std::chrono::milliseconds> timeout);
private:
struct Info {
@@ -55,6 +55,7 @@
android::base::unique_fd epoll_fd_;
std::map<int, Info> epoll_handlers_;
std::function<void()> first_callback_;
+ std::unordered_set<int> to_remove_;
};
} // namespace init
diff --git a/init/epoll_test.cpp b/init/epoll_test.cpp
index 3f8b5a4..7105a68 100644
--- a/init/epoll_test.cpp
+++ b/init/epoll_test.cpp
@@ -60,14 +60,9 @@
uint8_t byte = 0xee;
ASSERT_TRUE(android::base::WriteFully(fds[1], &byte, sizeof(byte)));
- auto results = epoll.Wait({});
- ASSERT_RESULT_OK(results);
- ASSERT_EQ(results->size(), size_t(1));
-
- for (const auto& function : *results) {
- (*function)();
- (*function)();
- }
+ auto epoll_result = epoll.Wait({});
+ ASSERT_RESULT_OK(epoll_result);
+ ASSERT_EQ(*epoll_result, 1);
ASSERT_TRUE(handler_invoked);
}
diff --git a/init/init.cpp b/init/init.cpp
index f9e7c6e..837d955 100644
--- a/init/init.cpp
+++ b/init/init.cpp
@@ -1177,14 +1177,10 @@
if (am.HasMoreCommands()) epoll_timeout = 0ms;
}
- auto pending_functions = epoll.Wait(epoll_timeout);
- if (!pending_functions.ok()) {
- LOG(ERROR) << pending_functions.error();
- } else if (!pending_functions->empty()) {
- for (const auto& function : *pending_functions) {
- (*function)();
- }
- } else if (Service::is_exec_service_running()) {
+ auto epoll_result = epoll.Wait(epoll_timeout);
+ if (!epoll_result.ok()) {
+ LOG(ERROR) << epoll_result.error();
+ } else if (*epoll_result <= 0 && Service::is_exec_service_running()) {
static bool dumped_diagnostics = false;
std::chrono::duration<double> waited =
std::chrono::steady_clock::now() - Service::exec_service_started();
diff --git a/init/keychords_test.cpp b/init/keychords_test.cpp
index 8a333a2..5789bf5 100644
--- a/init/keychords_test.cpp
+++ b/init/keychords_test.cpp
@@ -212,11 +212,8 @@
}
void TestFrame::RelaxForMs(std::chrono::milliseconds wait) {
- auto pending_functions = epoll_.Wait(wait);
- ASSERT_RESULT_OK(pending_functions);
- for (const auto& function : *pending_functions) {
- (*function)();
- }
+ auto epoll_result = epoll_.Wait(wait);
+ ASSERT_RESULT_OK(epoll_result);
}
void TestFrame::SetChord(int key, bool value) {
diff --git a/init/property_service.cpp b/init/property_service.cpp
index c2ba8d5..f3550a1 100644
--- a/init/property_service.cpp
+++ b/init/property_service.cpp
@@ -1381,13 +1381,9 @@
}
while (true) {
- auto pending_functions = epoll.Wait(std::nullopt);
- if (!pending_functions.ok()) {
- LOG(ERROR) << pending_functions.error();
- } else {
- for (const auto& function : *pending_functions) {
- (*function)();
- }
+ auto epoll_result = epoll.Wait(std::nullopt);
+ if (!epoll_result.ok()) {
+ LOG(ERROR) << epoll_result.error();
}
}
}