Add a cache to getService with invalidation

Test: atest binderCacheUnitTest
Bug: 333854840
Flag: LIBBINDER_CLIENT_CACHE
Change-Id: I1b4e8482817c422850aed7359beaf7174b55445e
diff --git a/libs/binder/BackendUnifiedServiceManager.h b/libs/binder/BackendUnifiedServiceManager.h
index 387eb35..47b2ec9 100644
--- a/libs/binder/BackendUnifiedServiceManager.h
+++ b/libs/binder/BackendUnifiedServiceManager.h
@@ -18,9 +18,119 @@
 #include <android/os/BnServiceManager.h>
 #include <android/os/IServiceManager.h>
 #include <binder/IPCThreadState.h>
+#include <map>
+#include <memory>
+#include <mutex>
+#include <string>
 
 namespace android {
 
+// Mutex-guarded map from service name to its binder. An entry for a remote binder
+// is erased again when that binder's death notification fires, so lookups do
+// not keep returning a binder whose death has been reported.
+//
+// setItem() calls shared_from_this(), so instances must be owned by a
+// std::shared_ptr.
+class BinderCacheWithInvalidation
+      : public std::enable_shared_from_this<BinderCacheWithInvalidation> {
+    // Death recipient attached to a cached binder. It remembers the key the
+    // binder was stored under and asks the cache to drop that entry on death.
+    class BinderInvalidation : public IBinder::DeathRecipient {
+    public:
+        BinderInvalidation(std::weak_ptr<BinderCacheWithInvalidation> cache, const std::string& key)
+              : mCache(cache), mKey(key) {}
+
+        // Erases the cache entry for mKey, provided it still refers to `who`.
+        void binderDied(const wp<IBinder>& who) override {
+            sp<IBinder> binder = who.promote();
+            if (std::shared_ptr<BinderCacheWithInvalidation> cache = mCache.lock()) {
+                cache->removeItem(mKey, binder);
+            } else {
+                ALOGI("Binder Cache pointer expired: %s", mKey.c_str());
+            }
+        }
+
+    private:
+        // Weak so that a pending death notification does not keep the cache alive.
+        std::weak_ptr<BinderCacheWithInvalidation> mCache;
+        std::string mKey;
+    };
+    // A cached binder together with the recipient created for it in setItem().
+    struct Entry {
+        sp<IBinder> service;
+        sp<BinderInvalidation> deathRecipient;
+    };
+
+public:
+    // Returns the binder cached under `key`, or nullptr if there is none.
+    sp<IBinder> getItem(const std::string& key) const {
+        std::lock_guard<std::mutex> lock(mCacheMutex);
+
+        if (auto it = mCache.find(key); it != mCache.end()) {
+            return it->second.service;
+        }
+        return nullptr;
+    }
+
+    // Erases the entry for `key` if it currently holds `who`, unlinking its
+    // death recipient first. Returns true if an entry was erased. An entry that
+    // holds a different binder (e.g. the key was re-populated) is left alone.
+    bool removeItem(const std::string& key, const sp<IBinder>& who) {
+        std::lock_guard<std::mutex> lock(mCacheMutex);
+        if (auto it = mCache.find(key); it != mCache.end()) {
+            if (it->second.service == who) {
+                // Called from BinderInvalidation::binderDied(); any result other
+                // than DEAD_OBJECT is logged as a warning.
+                status_t result = who->unlinkToDeath(it->second.deathRecipient);
+                if (result != DEAD_OBJECT) {
+                    ALOGW("Unlinking to dead binder resulted in: %d", result);
+                }
+                // Erase through the iterator instead of looking the key up again.
+                mCache.erase(it);
+                return true;
+            }
+        }
+        return false;
+    }
+
+    // Stores `item` under `key`, replacing any entry already there. For a remote
+    // binder a death recipient is linked first so the entry is erased when the
+    // binder dies. Returns BAD_VALUE for a null `item`, the linkToDeath() error
+    // if linking fails (nothing is cached in either case), otherwise ok().
+    binder::Status setItem(const std::string& key, const sp<IBinder>& item) {
+        if (item == nullptr) {
+            ALOGE("Refusing to cache a null binder for service %s", key.c_str());
+            return binder::Status::fromStatusT(BAD_VALUE);
+        }
+        sp<BinderInvalidation> deathRecipient =
+                sp<BinderInvalidation>::make(shared_from_this(), key);
+
+        // linkToDeath if binder is a remote binder.
+        if (item->localBinder() == nullptr) {
+            status_t status = item->linkToDeath(deathRecipient);
+            if (status != android::OK) {
+                ALOGE("Failed to linkToDeath binder for service %s. Error: %d", key.c_str(),
+                      status);
+                return binder::Status::fromStatusT(status);
+            }
+        }
+        // NOTE(review): the link above happens before the lock is taken; if the
+        // binder dies in between, removeItem() may run before the entry exists.
+        // Confirm whether that window matters for callers.
+        std::lock_guard<std::mutex> lock(mCacheMutex);
+        mCache.insert_or_assign(key, Entry{.service = item, .deathRecipient = deathRecipient});
+        return binder::Status::ok();
+    }
+
+    // Declared only; the definition is not part of this header.
+    bool isClientSideCachingEnabled(const std::string& serviceName);
+
+private:
+    std::map<std::string, Entry> mCache;
+    // Guards mCache; mutable so that getItem() can stay const.
+    mutable std::mutex mCacheMutex;
+};
+
 class BackendUnifiedServiceManager : public android::os::BnServiceManager {
 public:
     explicit BackendUnifiedServiceManager(const sp<os::IServiceManager>& impl);
@@ -58,9 +168,12 @@
     }
 
 private:
+    std::shared_ptr<BinderCacheWithInvalidation> mCacheForGetService;
     sp<os::IServiceManager> mTheRealServiceManager;
     binder::Status toBinderService(const ::std::string& name, const os::Service& in,
                                    os::Service* _out);
+    binder::Status updateCache(const std::string& serviceName, const os::Service& service);
+    bool returnIfCached(const std::string& serviceName, os::Service* _out);
 };
 
 sp<BackendUnifiedServiceManager> getBackendUnifiedServiceManager();