diff --git a/DrmHwcTwo.cpp b/DrmHwcTwo.cpp
index e4c8b1f..b7e20f0 100644
--- a/DrmHwcTwo.cpp
+++ b/DrmHwcTwo.cpp
@@ -54,8 +54,8 @@
     return HWC2::Error::NoResources;
   }
   displays_.emplace(std::piecewise_construct, std::forward_as_tuple(displ),
-                    std::forward_as_tuple(&resource_manager_, drm, displ,
-                                          type));
+                    std::forward_as_tuple(&resource_manager_, drm, displ, type,
+                                          this));
 
   DrmCrtc *crtc = drm->GetCrtcForDisplay(static_cast<int>(displ));
   if (!crtc) {
@@ -183,24 +183,23 @@
                                         hwc2_function_pointer_t function) {
   supported(__func__);
 
+  std::unique_lock<std::mutex> lock(callback_lock_);
+
   switch (static_cast<HWC2::Callback>(descriptor)) {
     case HWC2::Callback::Hotplug: {
-      SetHotplugCallback(data, function);
+      hotplug_callback_ = std::make_pair(HWC2_PFN_HOTPLUG(function), data);
+      lock.unlock();
       const auto &drm_devices = resource_manager_.getDrmDevices();
       for (const auto &device : drm_devices)
         HandleInitialHotplugState(device.get());
       break;
     }
     case HWC2::Callback::Refresh: {
-      for (std::pair<const hwc2_display_t, DrmHwcTwo::HwcDisplay> &d :
-           displays_)
-        d.second.RegisterRefreshCallback(data, function);
+      refresh_callback_ = std::make_pair(HWC2_PFN_REFRESH(function), data);
       break;
     }
     case HWC2::Callback::Vsync: {
-      for (std::pair<const hwc2_display_t, DrmHwcTwo::HwcDisplay> &d :
-           displays_)
-        d.second.RegisterVsyncCallback(data, function);
+      vsync_callback_ = std::make_pair(HWC2_PFN_VSYNC(function), data);
       break;
     }
     default:
@@ -211,8 +210,9 @@
 
 DrmHwcTwo::HwcDisplay::HwcDisplay(ResourceManager *resource_manager,
                                   DrmDevice *drm, hwc2_display_t handle,
-                                  HWC2::DisplayType type)
-    : resource_manager_(resource_manager),
+                                  HWC2::DisplayType type, DrmHwcTwo *hwc2)
+    : hwc2_(hwc2),
+      resource_manager_(resource_manager),
       drm_(drm),
       handle_(handle),
       type_(type),
@@ -240,7 +240,14 @@
   }
 
   int display = static_cast<int>(handle_);
-  int ret = compositor_.Init(resource_manager_, display);
+  int ret = compositor_.Init(resource_manager_, display, [this] {
+    /* refresh callback */
+    const std::lock_guard<std::mutex> lock(hwc2_->callback_lock_);
+    if (hwc2_->refresh_callback_.first != nullptr &&
+        hwc2_->refresh_callback_.second != nullptr) {
+      hwc2_->refresh_callback_.first(hwc2_->refresh_callback_.second, handle_);
+    }
+  });
   if (ret) {
     ALOGE("Failed display compositor init for display %d (%d)", display, ret);
     return HWC2::Error::NoResources;
@@ -271,7 +278,15 @@
     return HWC2::Error::BadDisplay;
   }
 
-  ret = vsync_worker_.Init(drm_, display);
+  ret = vsync_worker_.Init(drm_, display, [this](int64_t timestamp) {
+    /* vsync callback */
+    const std::lock_guard<std::mutex> lock(hwc2_->callback_lock_);
+    if (hwc2_->vsync_callback_.first != nullptr &&
+        hwc2_->vsync_callback_.second != nullptr) {
+      hwc2_->vsync_callback_.first(hwc2_->vsync_callback_.second, handle_,
+                                   timestamp);
+    }
+  });
   if (ret) {
     ALOGE("Failed to create event worker for d=%d %d\n", display, ret);
     return HWC2::Error::BadDisplay;
@@ -296,18 +311,6 @@
   return SetActiveConfig(connector_->get_preferred_mode_id());
 }
 
-void DrmHwcTwo::HwcDisplay::RegisterVsyncCallback(
-    hwc2_callback_data_t data, hwc2_function_pointer_t func) {
-  supported(__func__);
-  vsync_worker_.RegisterClientCallback(data, func);
-}
-
-void DrmHwcTwo::HwcDisplay::RegisterRefreshCallback(
-    hwc2_callback_data_t data, hwc2_function_pointer_t func) {
-  supported(__func__);
-  compositor_.SetRefreshCallback(data, func);
-}
-
 HWC2::Error DrmHwcTwo::HwcDisplay::AcceptDisplayChanges() {
   supported(__func__);
   for (std::pair<const hwc2_layer_t, DrmHwcTwo::HwcLayer> &l : layers_)
@@ -1213,13 +1216,15 @@
 }
 
 void DrmHwcTwo::HandleDisplayHotplug(hwc2_display_t displayid, int state) {
-  const std::lock_guard<std::mutex> lock(hotplug_callback_lock);
+  const std::lock_guard<std::mutex> lock(callback_lock_);
 
-  if (hotplug_callback_hook_ && hotplug_callback_data_)
-    hotplug_callback_hook_(hotplug_callback_data_, displayid,
-                           state == DRM_MODE_CONNECTED
-                               ? HWC2_CONNECTION_CONNECTED
-                               : HWC2_CONNECTION_DISCONNECTED);
+  if (hotplug_callback_.first != nullptr &&
+      hotplug_callback_.second != nullptr) {
+    hotplug_callback_.first(hotplug_callback_.second, displayid,
+                            state == DRM_MODE_CONNECTED
+                                ? HWC2_CONNECTION_CONNECTED
+                                : HWC2_CONNECTION_DISCONNECTED);
+  }
 }
 
 void DrmHwcTwo::HandleInitialHotplugState(DrmDevice *drmDevice) {
diff --git a/DrmHwcTwo.h b/DrmHwcTwo.h
index 2055b7b..16c4dac 100644
--- a/DrmHwcTwo.h
+++ b/DrmHwcTwo.h
@@ -42,16 +42,11 @@
 
   HWC2::Error Init();
 
-  hwc2_callback_data_t hotplug_callback_data_ = NULL;
-  HWC2_PFN_HOTPLUG hotplug_callback_hook_ = NULL;
-  std::mutex hotplug_callback_lock;
+  std::pair<HWC2_PFN_HOTPLUG, hwc2_callback_data_t> hotplug_callback_{};
+  std::pair<HWC2_PFN_VSYNC, hwc2_callback_data_t> vsync_callback_{};
+  std::pair<HWC2_PFN_REFRESH, hwc2_callback_data_t> refresh_callback_{};
 
-  void SetHotplugCallback(hwc2_callback_data_t data,
-                          hwc2_function_pointer_t hook) {
-    const std::lock_guard<std::mutex> lock(hotplug_callback_lock);
-    hotplug_callback_data_ = data;
-    hotplug_callback_hook_ = reinterpret_cast<HWC2_PFN_HOTPLUG>(hook);
-  }
+  std::mutex callback_lock_;
 
   class HwcLayer {
    public:
@@ -147,14 +142,10 @@
   class HwcDisplay {
    public:
     HwcDisplay(ResourceManager *resource_manager, DrmDevice *drm,
-               hwc2_display_t handle, HWC2::DisplayType type);
+               hwc2_display_t handle, HWC2::DisplayType type, DrmHwcTwo *hwc2);
     HwcDisplay(const HwcDisplay &) = delete;
     HWC2::Error Init(std::vector<DrmPlane *> *planes);
 
-    void RegisterVsyncCallback(hwc2_callback_data_t data,
-                               hwc2_function_pointer_t func);
-    void RegisterRefreshCallback(hwc2_callback_data_t data,
-                                 hwc2_function_pointer_t func);
     HWC2::Error CreateComposition(bool test);
     std::vector<DrmHwcTwo::HwcLayer *> GetOrderLayersByZPos();
 
@@ -304,6 +295,8 @@
 
     constexpr static size_t MATRIX_SIZE = 16;
 
+    DrmHwcTwo *hwc2_;
+
     ResourceManager *resource_manager_;
     DrmDevice *drm_;
     DrmDisplayCompositor compositor_;
diff --git a/compositor/DrmDisplayCompositor.cpp b/compositor/DrmDisplayCompositor.cpp
index 08d998c..5ad2de2 100644
--- a/compositor/DrmDisplayCompositor.cpp
+++ b/compositor/DrmDisplayCompositor.cpp
@@ -49,20 +49,6 @@
   return str << flattenting_state_str[static_cast<int>(state)];
 }
 
-class CompositorVsyncCallback : public VsyncCallback {
- public:
-  explicit CompositorVsyncCallback(DrmDisplayCompositor *compositor)
-      : compositor_(compositor) {
-  }
-
-  void Callback(int display, int64_t timestamp) override {
-    compositor_->Vsync(display, timestamp);
-  }
-
- private:
-  DrmDisplayCompositor *compositor_;
-};
-
 DrmDisplayCompositor::DrmDisplayCompositor()
     : resource_manager_(nullptr),
       display_(-1),
@@ -103,7 +89,10 @@
   pthread_mutex_destroy(&lock_);
 }
 
-int DrmDisplayCompositor::Init(ResourceManager *resource_manager, int display) {
+auto DrmDisplayCompositor::Init(ResourceManager *resource_manager, int display,
+                                std::function<void()> client_refresh_callback)
+    -> int {
+  client_refresh_callback_ = std::move(client_refresh_callback);
   resource_manager_ = resource_manager;
   display_ = display;
   DrmDevice *drm = resource_manager_->GetDrmDevice(display);
@@ -118,9 +107,19 @@
   }
   planner_ = Planner::CreateInstance(drm);
 
-  vsync_worker_.Init(drm, display_);
-  auto callback = std::make_shared<CompositorVsyncCallback>(this);
-  vsync_worker_.RegisterCallback(callback);
+  vsync_worker_.Init(drm, display_, [this](int64_t timestamp) {
+    AutoLock lock(&lock_, "DrmDisplayCompositor::Init()");
+    if (lock.Lock())
+      return;
+    flatten_countdown_--;
+    if (!CountdownExpired())
+      return;
+    lock.Unlock();
+    int ret = FlattenActiveComposition();
+    ALOGV("scene flattening triggered for display %d at timestamp %" PRIu64
+          " result = %d \n",
+          display_, timestamp, ret);
+  });
 
   initialized_ = true;
   return 0;
@@ -445,9 +444,7 @@
 }
 
 int DrmDisplayCompositor::FlattenOnClient() {
-  const std::lock_guard<std::mutex> lock(refresh_callback_lock);
-
-  if (refresh_callback_hook_ && refresh_callback_data_) {
+  if (client_refresh_callback_) {
     {
       AutoLock lock(&lock_, __func__);
       if (!IsFlatteningNeeded()) {
@@ -463,7 +460,7 @@
         "No writeback connector available, "
         "falling back to client composition");
     SetFlattening(FlatteningState::kClientRequested);
-    refresh_callback_hook_(refresh_callback_data_, display_);
+    client_refresh_callback_();
     return 0;
   }
 
@@ -479,20 +476,6 @@
   return flatten_countdown_ <= 0;
 }
 
-void DrmDisplayCompositor::Vsync(int display, int64_t timestamp) {
-  AutoLock lock(&lock_, __func__);
-  if (lock.Lock())
-    return;
-  flatten_countdown_--;
-  if (!CountdownExpired())
-    return;
-  lock.Unlock();
-  int ret = FlattenActiveComposition();
-  ALOGV("scene flattening triggered for display %d at timestamp %" PRIu64
-        " result = %d \n",
-        display, timestamp, ret);
-}
-
 void DrmDisplayCompositor::Dump(std::ostringstream *out) const {
   int ret = pthread_mutex_lock(&lock_);
   if (ret)
diff --git a/compositor/DrmDisplayCompositor.h b/compositor/DrmDisplayCompositor.h
index c0eed0c..afdb79e 100644
--- a/compositor/DrmDisplayCompositor.h
+++ b/compositor/DrmDisplayCompositor.h
@@ -52,25 +52,14 @@
   DrmDisplayCompositor();
   ~DrmDisplayCompositor();
 
-  int Init(ResourceManager *resource_manager, int display);
-
-  hwc2_callback_data_t refresh_callback_data_ = NULL;
-  HWC2_PFN_REFRESH refresh_callback_hook_ = NULL;
-  std::mutex refresh_callback_lock;
-
-  void SetRefreshCallback(hwc2_callback_data_t data,
-                          hwc2_function_pointer_t hook) {
-    const std::lock_guard<std::mutex> lock(refresh_callback_lock);
-    refresh_callback_data_ = data;
-    refresh_callback_hook_ = reinterpret_cast<HWC2_PFN_REFRESH>(hook);
-  }
+  auto Init(ResourceManager *resource_manager, int display,
+            std::function<void()> client_refresh_callback) -> int;
 
   std::unique_ptr<DrmDisplayComposition> CreateInitializedComposition() const;
   int ApplyComposition(std::unique_ptr<DrmDisplayComposition> composition);
   int TestComposition(DrmDisplayComposition *composition);
   int Composite();
   void Dump(std::ostringstream *out) const;
-  void Vsync(int display, int64_t timestamp);
   void ClearDisplay();
   UniqueFd TakeOutFence() {
     if (!active_composition_) {
@@ -86,6 +75,7 @@
   std::tuple<uint32_t, uint32_t, int> GetActiveModeResolution();
 
  private:
+  std::function<void()> client_refresh_callback_;
   struct ModeState {
     bool needs_modeset = false;
     DrmMode mode;
diff --git a/drm/VSyncWorker.cpp b/drm/VSyncWorker.cpp
index 1c0de21..6e92838 100644
--- a/drm/VSyncWorker.cpp
+++ b/drm/VSyncWorker.cpp
@@ -37,27 +37,16 @@
       last_timestamp_(-1) {
 }
 
-int VSyncWorker::Init(DrmDevice *drm, int display) {
+auto VSyncWorker::Init(DrmDevice *drm, int display,
+                       std::function<void(uint64_t /*timestamp*/)> callback)
+    -> int {
   drm_ = drm;
   display_ = display;
+  callback_ = std::move(callback);
 
   return InitWorker();
 }
 
-void VSyncWorker::RegisterCallback(std::shared_ptr<VsyncCallback> callback) {
-  Lock();
-  callback_ = std::move(callback);
-  Unlock();
-}
-
-void VSyncWorker::RegisterClientCallback(hwc2_callback_data_t data,
-                                         hwc2_function_pointer_t hook) {
-  Lock();
-  vsync_callback_data_ = data;
-  vsync_callback_hook_ = (HWC2_PFN_VSYNC)hook;
-  Unlock();
-}
-
 void VSyncWorker::VSyncControl(bool enabled) {
   Lock();
   enabled_ = enabled;
@@ -133,7 +122,6 @@
   }
 
   int display = display_;
-  std::shared_ptr<VsyncCallback> callback(callback_);
   Unlock();
 
   DrmCrtc *crtc = drm_->GetCrtcForDisplay(display);
@@ -145,8 +133,9 @@
 
   drmVBlank vblank;
   memset(&vblank, 0, sizeof(vblank));
-  vblank.request.type = (drmVBlankSeqType)(
-      DRM_VBLANK_RELATIVE | (high_crtc & DRM_VBLANK_HIGH_CRTC_MASK));
+  vblank.request.type = (drmVBlankSeqType)(DRM_VBLANK_RELATIVE |
+                                           (high_crtc &
+                                            DRM_VBLANK_HIGH_CRTC_MASK));
   vblank.request.sequence = 1;
 
   int64_t timestamp = 0;
@@ -166,13 +155,9 @@
   if (!enabled_)
     return;
 
-  if (callback)
-    callback->Callback(display, timestamp);
-
-  Lock();
-  if (enabled_ && vsync_callback_hook_ && vsync_callback_data_)
-    vsync_callback_hook_(vsync_callback_data_, display, timestamp);
-  Unlock();
+  if (callback_) {
+    callback_(timestamp);
+  }
 
   last_timestamp_ = timestamp;
 }
diff --git a/drm/VSyncWorker.h b/drm/VSyncWorker.h
index b43918c..74ff487 100644
--- a/drm/VSyncWorker.h
+++ b/drm/VSyncWorker.h
@@ -23,6 +23,7 @@
 #include <stdint.h>
 
 #include <atomic>
+#include <functional>
 #include <map>
 
 #include "DrmDevice.h"
@@ -30,21 +31,13 @@
 
 namespace android {
 
-class VsyncCallback {
- public:
-  virtual ~VsyncCallback() = default;
-  virtual void Callback(int display, int64_t timestamp) = 0;
-};
-
 class VSyncWorker : public Worker {
  public:
   VSyncWorker();
   ~VSyncWorker() override = default;
 
-  int Init(DrmDevice *drm, int display);
-  void RegisterCallback(std::shared_ptr<VsyncCallback> callback);
-  void RegisterClientCallback(hwc2_callback_data_t data,
-                              hwc2_function_pointer_t hook);
+  auto Init(DrmDevice *drm, int display,
+            std::function<void(uint64_t /*timestamp*/)> callback) -> int;
 
   void VSyncControl(bool enabled);
 
@@ -57,17 +50,11 @@
 
   DrmDevice *drm_;
 
-  // shared_ptr since we need to use this outside of the thread lock (to
-  // actually call the hook) and we don't want the memory freed until we're
-  // done
-  std::shared_ptr<VsyncCallback> callback_ = NULL;
+  std::function<void(uint64_t /*timestamp*/)> callback_;
 
   int display_;
   std::atomic_bool enabled_;
   int64_t last_timestamp_;
-
-  hwc2_callback_data_t vsync_callback_data_ = NULL;
-  HWC2_PFN_VSYNC vsync_callback_hook_ = NULL;
 };
 }  // namespace android
 
