Add per-package lock in setAppFunctionEnabled.

Bug: 357551503
Test: CTS
Flag: android.app.appfunctions.flags.enable_app_function_manager
Change-Id: I3c46a6ef1cdb07e6e1a30a9d1417fac8c55574f3
diff --git a/services/appfunctions/java/com/android/server/appfunctions/AppFunctionManagerServiceImpl.java b/services/appfunctions/java/com/android/server/appfunctions/AppFunctionManagerServiceImpl.java
index c5fef19..2ee7561 100644
--- a/services/appfunctions/java/com/android/server/appfunctions/AppFunctionManagerServiceImpl.java
+++ b/services/appfunctions/java/com/android/server/appfunctions/AppFunctionManagerServiceImpl.java
@@ -65,12 +65,13 @@
 import com.android.internal.infra.AndroidFuture;
 import com.android.internal.util.DumpUtils;
 import com.android.server.SystemService.TargetUser;
-import com.android.server.appfunctions.RemoteServiceCaller.RunServiceCallCallback;
-import com.android.server.appfunctions.RemoteServiceCaller.ServiceUsageCompleteListener;
 
 import java.io.FileDescriptor;
 import java.io.PrintWriter;
+import java.util.Collections;
+import java.util.Map;
 import java.util.Objects;
+import java.util.WeakHashMap;
 import java.util.concurrent.CompletionException;
 import java.util.concurrent.Executor;
 
@@ -83,7 +84,8 @@
     private final ServiceHelper mInternalServiceHelper;
     private final ServiceConfig mServiceConfig;
     private final Context mContext;
-    private final Object mLock = new Object();
+    private final Map<String, Object> mLocks = new WeakHashMap<>();
+
 
     public AppFunctionManagerServiceImpl(@NonNull Context context) {
         this(
@@ -321,9 +323,7 @@
         THREAD_POOL_EXECUTOR.execute(
                 () -> {
                     try {
-                        // TODO(357551503): Instead of holding a global lock, hold a per-package
-                        //  lock.
-                        synchronized (mLock) {
+                        synchronized (getLockForPackage(callingPackage)) {
                             setAppFunctionEnabledInternalLocked(
                                     callingPackage, functionIdentifier, userHandle, enabledState);
                         }
@@ -351,7 +351,7 @@
      * process.
      */
     @WorkerThread
-    @GuardedBy("mLock")
+    @GuardedBy("getLockForPackage(callingPackage)")
     private void setAppFunctionEnabledInternalLocked(
             @NonNull String callingPackage,
             @NonNull String functionIdentifier,
@@ -545,6 +545,26 @@
                                     });
         }
     }
+    /**
+     * Retrieves the lock object associated with the given package name.
+     *
+     * This method returns the lock object from the {@code mLocks} map if it exists.
+     * If no lock is found for the given package name, a new lock object is created,
+     * stored in the map, and returned.
+     */
+    @VisibleForTesting
+    @NonNull
+    Object getLockForPackage(String callingPackage) {
+        // Synchronized the access to mLocks to prevent race condition.
+        synchronized (mLocks) {
+            // By using a WeakHashMap, we allow the garbage collector to reclaim memory by removing
+            // entries associated with unused callingPackage keys. Therefore, we remove the null
+            // values before getting/computing a new value. The goal is to not let the size of this
+            // map grow without an upper bound.
+            mLocks.values().removeAll(Collections.singleton(null)); // Remove null values
+            return mLocks.computeIfAbsent(callingPackage, k -> new Object());
+        }
+    }
 
     private static class AppFunctionMetadataObserver implements ObserverCallback {
         @Nullable private final MetadataSyncAdapter mPerUserMetadataSyncAdapter;
diff --git a/services/tests/appfunctions/Android.bp b/services/tests/appfunctions/Android.bp
index c841643..836f90b 100644
--- a/services/tests/appfunctions/Android.bp
+++ b/services/tests/appfunctions/Android.bp
@@ -36,7 +36,9 @@
         "androidx.test.core",
         "androidx.test.runner",
         "androidx.test.ext.truth",
+        "androidx.core_core-ktx",
         "kotlin-test",
+        "kotlinx_coroutines_test",
         "platform-test-annotations",
         "services.appfunctions",
         "servicestests-core-utils",
diff --git a/services/tests/appfunctions/src/com/android/server/appfunctions/AppFunctionManagerServiceImplTest.kt b/services/tests/appfunctions/src/com/android/server/appfunctions/AppFunctionManagerServiceImplTest.kt
new file mode 100644
index 0000000..a69e902
--- /dev/null
+++ b/services/tests/appfunctions/src/com/android/server/appfunctions/AppFunctionManagerServiceImplTest.kt
@@ -0,0 +1,89 @@
+/*
+ * Copyright (C) 2024 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.server.appfunctions
+
+import android.app.appfunctions.flags.Flags
+import android.content.Context
+import android.platform.test.annotations.RequiresFlagsEnabled
+import android.platform.test.flag.junit.CheckFlagsRule
+import android.platform.test.flag.junit.DeviceFlagsValueProvider
+import androidx.test.core.app.ApplicationProvider
+import androidx.test.ext.junit.runners.AndroidJUnit4
+import com.google.common.truth.Truth.assertThat
+import kotlinx.coroutines.test.runTest
+import org.junit.Ignore
+import org.junit.Rule
+import org.junit.Test
+import org.junit.runner.RunWith
+
+@RunWith(AndroidJUnit4::class)
+@RequiresFlagsEnabled(Flags.FLAG_ENABLE_APP_FUNCTION_MANAGER)
+class AppFunctionManagerServiceImplTest {
+    @get:Rule
+    val checkFlagsRule: CheckFlagsRule = DeviceFlagsValueProvider.createCheckFlagsRule()
+
+    private val context: Context
+        get() = ApplicationProvider.getApplicationContext()
+
+    private val serviceImpl = AppFunctionManagerServiceImpl(context)
+
+    @Test
+    fun testGetLockForPackage_samePackage() {
+        val packageName = "com.example.app"
+        val lock1 = serviceImpl.getLockForPackage(packageName)
+        val lock2 = serviceImpl.getLockForPackage(packageName)
+
+        // Assert that the same lock object is returned for the same package name
+        assertThat(lock1).isEqualTo(lock2)
+    }
+
+    @Test
+    fun testGetLockForPackage_differentPackages() {
+        val packageName1 = "com.example.app1"
+        val packageName2 = "com.example.app2"
+        val lock1 = serviceImpl.getLockForPackage(packageName1)
+        val lock2 = serviceImpl.getLockForPackage(packageName2)
+
+        // Assert that different lock objects are returned for different package names
+        assertThat(lock1).isNotEqualTo(lock2)
+    }
+
+    @Ignore("Hard to deterministically trigger the garbage collector.")
+    @Test
+    fun testWeakReference_garbageCollected_differentLockAfterGC() = runTest {
+        // Create a large number of temporary objects to put pressure on the GC
+        val tempObjects = MutableList<Any?>(10000000) { Any() }
+        var callingPackage: String? = "com.example.app"
+        var lock1: Any? = serviceImpl.getLockForPackage(callingPackage)
+        callingPackage = null // Set the key to null
+        val lock1Hash = lock1.hashCode()
+        lock1 = null
+
+        // Create memory pressure
+        repeat(3) {
+            for (i in 1..100) {
+                "a".repeat(10000)
+            }
+            System.gc() // Suggest garbage collection
+            System.runFinalization()
+        }
+        // Get the lock again - it should be a different object now
+        val lock2 = serviceImpl.getLockForPackage("com.example.app")
+        // Assert that the lock objects are different
+        assertThat(lock1Hash).isNotEqualTo(lock2.hashCode())
+    }
+}