Use std::shared_ptr for SpriteController

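Remove RefBase from SpriteController and manage its lifetime with
std::shared_ptr instead. We cannot migrate to std::unique_ptr because the
controller posts messages to its Looper message handler, and that handler
must hold a weak reference (std::weak_ptr) to the controller so it can
detect that the controller has already been destroyed.

Clients (PointerController, PointerControllerContext, and the sprites
returned by createSprite()) now hold a plain SpriteController& instead of a
strong pointer, so they must not outlive the SpriteController.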

Bug: 278783893
Test: presubmit
Change-Id: I0ea4bb220e5b1866375ed39335f9035cd4bb766c
diff --git a/libs/input/MouseCursorController.cpp b/libs/input/MouseCursorController.cpp
index c3ad767..6a46544 100644
--- a/libs/input/MouseCursorController.cpp
+++ b/libs/input/MouseCursorController.cpp
@@ -47,7 +47,7 @@
     mLocked.pointerX = 0;
     mLocked.pointerY = 0;
     mLocked.pointerAlpha = 0.0f; // pointer is initially faded
-    mLocked.pointerSprite = mContext.getSpriteController()->createSprite();
+    mLocked.pointerSprite = mContext.getSpriteController().createSprite();
     mLocked.updatePointerIcon = false;
     mLocked.requestedPointerType = PointerIconStyle::TYPE_NOT_SPECIFIED;
     mLocked.resolvedPointerType = PointerIconStyle::TYPE_NOT_SPECIFIED;
@@ -325,8 +325,8 @@
     }
 
     if (timestamp - mLocked.lastFrameUpdatedTime > iter->second.durationPerFrame) {
-        sp<SpriteController> spriteController = mContext.getSpriteController();
-        spriteController->openTransaction();
+        auto& spriteController = mContext.getSpriteController();
+        spriteController.openTransaction();
 
         int incr = (timestamp - mLocked.lastFrameUpdatedTime) / iter->second.durationPerFrame;
         mLocked.animationFrameIndex += incr;
@@ -336,7 +336,7 @@
         }
         mLocked.pointerSprite->setIcon(iter->second.animationFrames[mLocked.animationFrameIndex]);
 
-        spriteController->closeTransaction();
+        spriteController.closeTransaction();
     }
     // Keep animating.
     return true;
@@ -346,8 +346,8 @@
     if (!mLocked.viewport.isValid()) {
         return;
     }
-    sp<SpriteController> spriteController = mContext.getSpriteController();
-    spriteController->openTransaction();
+    auto& spriteController = mContext.getSpriteController();
+    spriteController.openTransaction();
 
     mLocked.pointerSprite->setLayer(Sprite::BASE_LAYER_POINTER);
     mLocked.pointerSprite->setPosition(mLocked.pointerX, mLocked.pointerY);
@@ -392,7 +392,7 @@
         mLocked.updatePointerIcon = false;
     }
 
-    spriteController->closeTransaction();
+    spriteController.closeTransaction();
 }
 
 void MouseCursorController::loadResourcesLocked(bool getAdditionalMouseResources) REQUIRES(mLock) {
diff --git a/libs/input/PointerController.cpp b/libs/input/PointerController.cpp
index bb3d9d7..435452c 100644
--- a/libs/input/PointerController.cpp
+++ b/libs/input/PointerController.cpp
@@ -63,7 +63,7 @@
 
 std::shared_ptr<PointerController> PointerController::create(
         const sp<PointerControllerPolicyInterface>& policy, const sp<Looper>& looper,
-        const sp<SpriteController>& spriteController) {
+        SpriteController& spriteController) {
     // using 'new' to access non-public constructor
     std::shared_ptr<PointerController> controller = std::shared_ptr<PointerController>(
             new PointerController(policy, looper, spriteController));
@@ -85,8 +85,7 @@
 }
 
 PointerController::PointerController(const sp<PointerControllerPolicyInterface>& policy,
-                                     const sp<Looper>& looper,
-                                     const sp<SpriteController>& spriteController)
+                                     const sp<Looper>& looper, SpriteController& spriteController)
       : PointerController(
                 policy, looper, spriteController,
                 [](const sp<android::gui::WindowInfosListener>& listener) {
@@ -97,8 +96,7 @@
                 }) {}
 
 PointerController::PointerController(const sp<PointerControllerPolicyInterface>& policy,
-                                     const sp<Looper>& looper,
-                                     const sp<SpriteController>& spriteController,
+                                     const sp<Looper>& looper, SpriteController& spriteController,
                                      WindowListenerConsumer registerListener,
                                      WindowListenerConsumer unregisterListener)
       : mContext(policy, looper, spriteController, *this),
diff --git a/libs/input/PointerController.h b/libs/input/PointerController.h
index 62ee743..c7e772d 100644
--- a/libs/input/PointerController.h
+++ b/libs/input/PointerController.h
@@ -47,7 +47,7 @@
 public:
     static std::shared_ptr<PointerController> create(
             const sp<PointerControllerPolicyInterface>& policy, const sp<Looper>& looper,
-            const sp<SpriteController>& spriteController);
+            SpriteController& spriteController);
 
     ~PointerController() override;
 
@@ -83,13 +83,12 @@
 
     // Constructor used to test WindowInfosListener registration.
     PointerController(const sp<PointerControllerPolicyInterface>& policy, const sp<Looper>& looper,
-                      const sp<SpriteController>& spriteController,
-                      WindowListenerConsumer registerListener,
+                      SpriteController& spriteController, WindowListenerConsumer registerListener,
                       WindowListenerConsumer unregisterListener);
 
 private:
     PointerController(const sp<PointerControllerPolicyInterface>& policy, const sp<Looper>& looper,
-                      const sp<SpriteController>& spriteController);
+                      SpriteController& spriteController);
 
     friend PointerControllerContext::LooperCallback;
     friend PointerControllerContext::MessageHandler;
diff --git a/libs/input/PointerControllerContext.cpp b/libs/input/PointerControllerContext.cpp
index c1545107..15c3517 100644
--- a/libs/input/PointerControllerContext.cpp
+++ b/libs/input/PointerControllerContext.cpp
@@ -32,7 +32,7 @@
 
 PointerControllerContext::PointerControllerContext(
         const sp<PointerControllerPolicyInterface>& policy, const sp<Looper>& looper,
-        const sp<SpriteController>& spriteController, PointerController& controller)
+        SpriteController& spriteController, PointerController& controller)
       : mPolicy(policy),
         mLooper(looper),
         mSpriteController(spriteController),
@@ -93,7 +93,7 @@
     return mPolicy;
 }
 
-sp<SpriteController> PointerControllerContext::getSpriteController() {
+SpriteController& PointerControllerContext::getSpriteController() {
     return mSpriteController;
 }
 
diff --git a/libs/input/PointerControllerContext.h b/libs/input/PointerControllerContext.h
index f6f5d3b..98c3988 100644
--- a/libs/input/PointerControllerContext.h
+++ b/libs/input/PointerControllerContext.h
@@ -92,7 +92,7 @@
 class PointerControllerContext {
 public:
     PointerControllerContext(const sp<PointerControllerPolicyInterface>& policy,
-                             const sp<Looper>& looper, const sp<SpriteController>& spriteController,
+                             const sp<Looper>& looper, SpriteController& spriteController,
                              PointerController& controller);
     ~PointerControllerContext();
 
@@ -109,7 +109,7 @@
     void setCallbackController(std::shared_ptr<PointerController> controller);
 
     sp<PointerControllerPolicyInterface> getPolicy();
-    sp<SpriteController> getSpriteController();
+    SpriteController& getSpriteController();
 
     void handleDisplayEvents();
 
@@ -163,7 +163,7 @@
 
     sp<PointerControllerPolicyInterface> mPolicy;
     sp<Looper> mLooper;
-    sp<SpriteController> mSpriteController;
+    SpriteController& mSpriteController;
     sp<MessageHandler> mHandler;
     sp<LooperCallback> mCallback;
 
diff --git a/libs/input/SpriteController.cpp b/libs/input/SpriteController.cpp
index d40f49e..6dc45a6 100644
--- a/libs/input/SpriteController.cpp
+++ b/libs/input/SpriteController.cpp
@@ -37,10 +37,10 @@
     mLocked.deferredSpriteUpdate = false;
 }
 
-void SpriteController::setHandlerController(const sp<android::SpriteController>& controller) {
-    // Initialize the weak message handler outside the constructor, because we cannot get a strong
-    // pointer to self in the constructor as the initial ref count is only incremented after
-    // construction.
+void SpriteController::setHandlerController(
+        const std::shared_ptr<android::SpriteController>& controller) {
+    // Give the handler its weak reference to this controller here rather than in the
+    // constructor, because a std::shared_ptr to self cannot be obtained during construction.
     mHandler->spriteController = controller;
 }
 
@@ -54,7 +54,7 @@
 }
 
 sp<Sprite> SpriteController::createSprite() {
-    return sp<SpriteImpl>::make(sp<SpriteController>::fromExisting(this));
+    return sp<SpriteImpl>::make(*this);
 }
 
 void SpriteController::openTransaction() {
@@ -352,7 +352,7 @@
 // --- SpriteController::Handler ---
 
 void SpriteController::Handler::handleMessage(const android::Message& message) {
-    auto controller = spriteController.promote();
+    auto controller = spriteController.lock();
     if (!controller) {
         return;
     }
@@ -369,22 +369,21 @@
 
 // --- SpriteController::SpriteImpl ---
 
-SpriteController::SpriteImpl::SpriteImpl(const sp<SpriteController>& controller)
-      : mController(controller) {}
+SpriteController::SpriteImpl::SpriteImpl(SpriteController& controller) : mController(controller) {}
 
 SpriteController::SpriteImpl::~SpriteImpl() {
-    AutoMutex _m(mController->mLock);
+    AutoMutex _m(mController.mLock);
 
     // Let the controller take care of deleting the last reference to sprite
     // surfaces so that we do not block the caller on an IPC here.
     if (mLocked.state.surfaceControl != NULL) {
-        mController->disposeSurfaceLocked(mLocked.state.surfaceControl);
+        mController.disposeSurfaceLocked(mLocked.state.surfaceControl);
         mLocked.state.surfaceControl.clear();
     }
 }
 
 void SpriteController::SpriteImpl::setIcon(const SpriteIcon& icon) {
-    AutoMutex _l(mController->mLock);
+    AutoMutex _l(mController.mLock);
 
     uint32_t dirty;
     if (icon.isValid()) {
@@ -414,7 +413,7 @@
 }
 
 void SpriteController::SpriteImpl::setVisible(bool visible) {
-    AutoMutex _l(mController->mLock);
+    AutoMutex _l(mController.mLock);
 
     if (mLocked.state.visible != visible) {
         mLocked.state.visible = visible;
@@ -423,7 +422,7 @@
 }
 
 void SpriteController::SpriteImpl::setPosition(float x, float y) {
-    AutoMutex _l(mController->mLock);
+    AutoMutex _l(mController.mLock);
 
     if (mLocked.state.positionX != x || mLocked.state.positionY != y) {
         mLocked.state.positionX = x;
@@ -433,7 +432,7 @@
 }
 
 void SpriteController::SpriteImpl::setLayer(int32_t layer) {
-    AutoMutex _l(mController->mLock);
+    AutoMutex _l(mController.mLock);
 
     if (mLocked.state.layer != layer) {
         mLocked.state.layer = layer;
@@ -442,7 +441,7 @@
 }
 
 void SpriteController::SpriteImpl::setAlpha(float alpha) {
-    AutoMutex _l(mController->mLock);
+    AutoMutex _l(mController.mLock);
 
     if (mLocked.state.alpha != alpha) {
         mLocked.state.alpha = alpha;
@@ -452,7 +451,7 @@
 
 void SpriteController::SpriteImpl::setTransformationMatrix(
         const SpriteTransformationMatrix& matrix) {
-    AutoMutex _l(mController->mLock);
+    AutoMutex _l(mController.mLock);
 
     if (mLocked.state.transformationMatrix != matrix) {
         mLocked.state.transformationMatrix = matrix;
@@ -461,7 +460,7 @@
 }
 
 void SpriteController::SpriteImpl::setDisplayId(int32_t displayId) {
-    AutoMutex _l(mController->mLock);
+    AutoMutex _l(mController.mLock);
 
     if (mLocked.state.displayId != displayId) {
         mLocked.state.displayId = displayId;
@@ -474,7 +473,7 @@
     mLocked.state.dirty |= dirty;
 
     if (!wasDirty) {
-        mController->invalidateSpriteLocked(sp<SpriteImpl>::fromExisting(this));
+        mController.invalidateSpriteLocked(sp<SpriteImpl>::fromExisting(this));
     }
 }
 
diff --git a/libs/input/SpriteController.h b/libs/input/SpriteController.h
index 3144401..04ecb38 100644
--- a/libs/input/SpriteController.h
+++ b/libs/input/SpriteController.h
@@ -109,18 +109,19 @@
  *
  * Clients are responsible for animating sprites by periodically updating their properties.
  */
-class SpriteController : public RefBase {
-protected:
-    virtual ~SpriteController();
-
+class SpriteController {
 public:
     using ParentSurfaceProvider = std::function<sp<SurfaceControl>(int /*displayId*/)>;
     SpriteController(const sp<Looper>& looper, int32_t overlayLayer, ParentSurfaceProvider parent);
+    SpriteController(const SpriteController&) = delete;
+    SpriteController& operator=(const SpriteController&) = delete;
+    virtual ~SpriteController();
 
     /* Initialize the callback for the message handler. */
-    void setHandlerController(const sp<SpriteController>& controller);
+    void setHandlerController(const std::shared_ptr<SpriteController>& controller);
 
-    /* Creates a new sprite, initially invisible. */
+    /* Creates a new sprite, initially invisible. The sprite keeps a reference to this
+     * SpriteController, so it must not outlive this SpriteController. */
     virtual sp<Sprite> createSprite();
 
     /* Opens or closes a transaction to perform a batch of sprite updates as part of
@@ -137,7 +138,7 @@
         enum { MSG_UPDATE_SPRITES, MSG_DISPOSE_SURFACES };
 
         void handleMessage(const Message& message) override;
-        wp<SpriteController> spriteController;
+        std::weak_ptr<SpriteController> spriteController;
     };
 
     enum {
@@ -198,7 +199,7 @@
         virtual ~SpriteImpl();
 
     public:
-        explicit SpriteImpl(const sp<SpriteController>& controller);
+        explicit SpriteImpl(SpriteController& controller);
 
         virtual void setIcon(const SpriteIcon& icon);
         virtual void setVisible(bool visible);
@@ -226,7 +227,7 @@
         }
 
     private:
-        sp<SpriteController> mController;
+        SpriteController& mController;
 
         struct Locked {
             SpriteState state;
diff --git a/libs/input/TouchSpotController.cpp b/libs/input/TouchSpotController.cpp
index c212608..b8de919 100644
--- a/libs/input/TouchSpotController.cpp
+++ b/libs/input/TouchSpotController.cpp
@@ -98,8 +98,8 @@
 #endif
 
     std::scoped_lock lock(mLock);
-    sp<SpriteController> spriteController = mContext.getSpriteController();
-    spriteController->openTransaction();
+    auto& spriteController = mContext.getSpriteController();
+    spriteController.openTransaction();
 
     // Add or move spots for fingers that are down.
     for (BitSet32 idBits(spotIdBits); !idBits.isEmpty();) {
@@ -125,7 +125,7 @@
         }
     }
 
-    spriteController->closeTransaction();
+    spriteController.closeTransaction();
 }
 
 void TouchSpotController::clearSpots() {
@@ -167,7 +167,7 @@
         sprite = mLocked.recycledSprites.back();
         mLocked.recycledSprites.pop_back();
     } else {
-        sprite = mContext.getSpriteController()->createSprite();
+        sprite = mContext.getSpriteController().createSprite();
     }
 
     // Return the new spot.
diff --git a/libs/input/tests/PointerController_test.cpp b/libs/input/tests/PointerController_test.cpp
index 8574751..3e2e43f 100644
--- a/libs/input/tests/PointerController_test.cpp
+++ b/libs/input/tests/PointerController_test.cpp
@@ -157,7 +157,7 @@
 
     sp<MockSprite> mPointerSprite;
     sp<MockPointerControllerPolicyInterface> mPolicy;
-    sp<MockSpriteController> mSpriteController;
+    std::unique_ptr<MockSpriteController> mSpriteController;
     std::shared_ptr<PointerController> mPointerController;
 
 private:
@@ -175,14 +175,13 @@
 
 PointerControllerTest::PointerControllerTest() : mPointerSprite(new NiceMock<MockSprite>),
         mLooper(new MyLooper), mThread(&PointerControllerTest::loopThread, this) {
-
-    mSpriteController = new NiceMock<MockSpriteController>(mLooper);
+    mSpriteController.reset(new NiceMock<MockSpriteController>(mLooper));
     mPolicy = new MockPointerControllerPolicyInterface();
 
     EXPECT_CALL(*mSpriteController, createSprite())
             .WillOnce(Return(mPointerSprite));
 
-    mPointerController = PointerController::create(mPolicy, mLooper, mSpriteController);
+    mPointerController = PointerController::create(mPolicy, mLooper, *mSpriteController);
 }
 
 PointerControllerTest::~PointerControllerTest() {
@@ -319,10 +318,9 @@
 class TestPointerController : public PointerController {
 public:
     TestPointerController(sp<android::gui::WindowInfosListener>& registeredListener,
-                          const sp<Looper>& looper)
+                          const sp<Looper>& looper, SpriteController& spriteController)
           : PointerController(
-                    new MockPointerControllerPolicyInterface(), looper,
-                    new NiceMock<MockSpriteController>(looper),
+                    new MockPointerControllerPolicyInterface(), looper, spriteController,
                     [&registeredListener](const sp<android::gui::WindowInfosListener>& listener) {
                         // Register listener
                         registeredListener = listener;
@@ -335,10 +333,12 @@
 
 TEST_F(PointerControllerWindowInfoListenerTest,
        doesNotCrashIfListenerCalledAfterPointerControllerDestroyed) {
+    sp<Looper> looper = new Looper(false);
+    auto spriteController = NiceMock<MockSpriteController>(looper);
     sp<android::gui::WindowInfosListener> registeredListener;
     sp<android::gui::WindowInfosListener> localListenerCopy;
     {
-        TestPointerController pointerController(registeredListener, new Looper(false));
+        TestPointerController pointerController(registeredListener, looper, spriteController);
         ASSERT_NE(nullptr, registeredListener) << "WindowInfosListener was not registered";
         localListenerCopy = registeredListener;
     }