Refactor A11yInteractionClient so every service has its own cache

- Maintain a static collection of caches (sCaches). Lock is
sConnectionCache
- Initialize a cache when a connection is added. (client.addConnection)
- Remove the cache when a connection is removed. (client.removeConnection)
- System server should not have its cache initialized, can remove
parameter in AccessibiltyInteractionClient constructor

I tried to maintain a local reference (mAccessibilityCache) to the
service cache for each client to avoid accessing the static
collection for each request. But since the process is not immediately
killed when the service is turned off, ensuring resources are cleared
correctly is big time investment that can be done later if desired.

Bug: 192489177
Test: Build and use talkback, do the slew of tests.

atest CtsAccessibilityServiceTestCases  CtsAccessibilityTestCases
CtsUiAutomationTestCases
FrameworksServicesTests:com.android.server.accessibility
FrameworksCoreTests:com.android.internal.accessibility
FrameworksCoreTests:android.view.accessibiliity

Change-Id: Ic30f6915c7b3c186a512ed8b410bf6e842eff5fb
diff --git a/core/java/android/accessibilityservice/AccessibilityService.java b/core/java/android/accessibilityservice/AccessibilityService.java
index 0f852b4..09af72d 100644
--- a/core/java/android/accessibilityservice/AccessibilityService.java
+++ b/core/java/android/accessibilityservice/AccessibilityService.java
@@ -2071,7 +2071,7 @@
             try {
                 connection.setServiceInfo(mInfo);
                 mInfo = null;
-                AccessibilityInteractionClient.getInstance(this).clearCache();
+                AccessibilityInteractionClient.getInstance(this).clearCache(mConnectionId);
             } catch (RemoteException re) {
                 Log.w(LOG_TAG, "Error while setting AccessibilityServiceInfo", re);
                 re.rethrowFromSystemServer();
@@ -2421,7 +2421,7 @@
                     if (event != null) {
                         // Send the event to AccessibilityCache via AccessibilityInteractionClient
                         AccessibilityInteractionClient.getInstance(mContext).onAccessibilityEvent(
-                                event);
+                                event, mConnectionId);
                         if (serviceWantsEvent
                                 && (mConnectionId != AccessibilityInteractionClient.NO_ID)) {
                             // Send the event to AccessibilityService
@@ -2451,7 +2451,7 @@
                     args.recycle();
                     if (connection != null) {
                         AccessibilityInteractionClient.getInstance(mContext).addConnection(
-                                mConnectionId, connection);
+                                mConnectionId, connection, /*initializeCache=*/true);
                         if (mContext != null) {
                             try {
                                 connection.setAttributionTag(mContext.getAttributionTag());
@@ -2466,7 +2466,8 @@
                         AccessibilityInteractionClient.getInstance(mContext).removeConnection(
                                 mConnectionId);
                         mConnectionId = AccessibilityInteractionClient.NO_ID;
-                        AccessibilityInteractionClient.getInstance(mContext).clearCache();
+                        AccessibilityInteractionClient.getInstance(mContext)
+                                .clearCache(mConnectionId);
                         mCallback.init(AccessibilityInteractionClient.NO_ID, null);
                     }
                     return;
@@ -2478,7 +2479,7 @@
                     return;
                 }
                 case DO_CLEAR_ACCESSIBILITY_CACHE: {
-                    AccessibilityInteractionClient.getInstance(mContext).clearCache();
+                    AccessibilityInteractionClient.getInstance(mContext).clearCache(mConnectionId);
                     return;
                 }
                 case DO_ON_KEY_EVENT: {
diff --git a/core/java/android/app/UiAutomation.java b/core/java/android/app/UiAutomation.java
index 828b171..58ded71 100644
--- a/core/java/android/app/UiAutomation.java
+++ b/core/java/android/app/UiAutomation.java
@@ -638,7 +638,7 @@
         final IAccessibilityServiceConnection connection;
         synchronized (mLock) {
             throwIfNotConnectedLocked();
-            AccessibilityInteractionClient.getInstance().clearCache();
+            AccessibilityInteractionClient.getInstance().clearCache(mConnectionId);
             connection = AccessibilityInteractionClient.getInstance()
                     .getConnection(mConnectionId);
         }
diff --git a/core/java/android/view/AccessibilityInteractionController.java b/core/java/android/view/AccessibilityInteractionController.java
index 9cd8313..de56d3a 100644
--- a/core/java/android/view/AccessibilityInteractionController.java
+++ b/core/java/android/view/AccessibilityInteractionController.java
@@ -151,8 +151,7 @@
             if (interrogatingPid == mMyProcessId && interrogatingTid == mMyLooperThreadId
                     && mHandler.hasAccessibilityCallback(message)) {
                 AccessibilityInteractionClient.getInstanceForThread(
-                        interrogatingTid, /* initializeCache= */true)
-                        .setSameThreadMessage(message);
+                        interrogatingTid).setSameThreadMessage(message);
             } else {
                 // For messages without callback of interrogating client, just handle the
                 // message immediately if this is UI thread.
diff --git a/core/java/android/view/accessibility/AccessibilityInteractionClient.java b/core/java/android/view/accessibility/AccessibilityInteractionClient.java
index bc21488..dc4c59a 100644
--- a/core/java/android/view/accessibility/AccessibilityInteractionClient.java
+++ b/core/java/android/view/accessibility/AccessibilityInteractionClient.java
@@ -115,13 +115,13 @@
         from a window, mapping from windowId -> timestamp. */
     private static final SparseLongArray sScrollingWindows = new SparseLongArray();
 
-    private static AccessibilityCache sAccessibilityCache;
+    private static SparseArray<AccessibilityCache> sCaches = new SparseArray<>();
 
     private final AtomicInteger mInteractionIdCounter = new AtomicInteger();
 
     private final Object mInstanceLock = new Object();
 
-    private final AccessibilityManager  mAccessibilityManager;
+    private final AccessibilityManager mAccessibilityManager;
 
     private volatile int mInteractionId = -1;
     private volatile int mCallingUid = Process.INVALID_UID;
@@ -150,7 +150,37 @@
     @UnsupportedAppUsage()
     public static AccessibilityInteractionClient getInstance() {
         final long threadId = Thread.currentThread().getId();
-        return getInstanceForThread(threadId, true);
+        return getInstanceForThread(threadId);
+    }
+
+    /**
+     * <strong>Note:</strong> We keep one instance per interrogating thread since
+     * the instance contains state which can lead to undesired thread interleavings.
+     * We do not have a thread local variable since other threads should be able to
+     * look up the correct client knowing a thread id. See ViewRootImpl for details.
+     *
+     * @return The client for a given <code>threadId</code>.
+     */
+    public static AccessibilityInteractionClient getInstanceForThread(long threadId) {
+        synchronized (sStaticLock) {
+            AccessibilityInteractionClient client = sClients.get(threadId);
+            if (client == null) {
+                client = new AccessibilityInteractionClient();
+                sClients.put(threadId, client);
+            }
+            return client;
+        }
+    }
+
+    /**
+     * @return The client for the current thread.
+     */
+    public static AccessibilityInteractionClient getInstance(Context context) {
+        final long threadId = Thread.currentThread().getId();
+        if (context != null) {
+            return getInstanceForThread(threadId, context);
+        }
+        return getInstanceForThread(threadId);
     }
 
     /**
@@ -162,61 +192,11 @@
      * @return The client for a given <code>threadId</code>.
      */
     public static AccessibilityInteractionClient getInstanceForThread(long threadId,
-            boolean initializeCache) {
-        synchronized (sStaticLock) {
-            AccessibilityInteractionClient client = sClients.get(threadId);
-            if (client == null) {
-                if (Binder.getCallingUid() == Process.SYSTEM_UID) {
-                    // Don't initialize a cache for the system process
-                    client = new AccessibilityInteractionClient(false);
-                } else {
-                    client = new AccessibilityInteractionClient(initializeCache);
-                }
-                sClients.put(threadId, client);
-            }
-            return client;
-        }
-    }
-
-    /**
-     * @return The client for the current thread.
-     */
-    public static AccessibilityInteractionClient getInstance(Context context) {
-        return getInstance(/* initializeCache= */true, context);
-    }
-
-    /**
-     * @param initializeCache whether to initialize the cache in a new client instance
-     * @return The client for the current thread.
-     */
-    public static AccessibilityInteractionClient getInstance(boolean initializeCache,
             Context context) {
-        final long threadId = Thread.currentThread().getId();
-        if (context != null) {
-            return getInstanceForThread(threadId, initializeCache, context);
-        }
-        return getInstanceForThread(threadId, initializeCache);
-    }
-
-    /**
-     * <strong>Note:</strong> We keep one instance per interrogating thread since
-     * the instance contains state which can lead to undesired thread interleavings.
-     * We do not have a thread local variable since other threads should be able to
-     * look up the correct client knowing a thread id. See ViewRootImpl for details.
-     *
-     * @param initializeCache whether to initialize the cache in a new client instance
-     * @return The client for a given <code>threadId</code>.
-     */
-    public static AccessibilityInteractionClient getInstanceForThread(
-            long threadId, boolean initializeCache, Context context) {
         synchronized (sStaticLock) {
             AccessibilityInteractionClient client = sClients.get(threadId);
             if (client == null) {
-                if (Binder.getCallingUid() == Process.SYSTEM_UID) {
-                    client = new AccessibilityInteractionClient(false, context);
-                } else {
-                    client = new AccessibilityInteractionClient(initializeCache, context);
-                }
+                client = new AccessibilityInteractionClient(context);
                 sClients.put(threadId, client);
             }
             return client;
@@ -238,12 +218,30 @@
     /**
      * Adds a cached accessibility service connection.
      *
+     * Adds a cache if {@code initializeCache} is true
      * @param connectionId The connection id.
      * @param connection The connection.
+     * @param initializeCache whether to initialize a cache
      */
-    public static void addConnection(int connectionId, IAccessibilityServiceConnection connection) {
+    public static void addConnection(int connectionId, IAccessibilityServiceConnection connection,
+            boolean initializeCache) {
         synchronized (sConnectionCache) {
             sConnectionCache.put(connectionId, connection);
+            if (!initializeCache) {
+                return;
+            }
+            sCaches.put(connectionId, new AccessibilityCache(
+                        new AccessibilityCache.AccessibilityNodeRefresher()));
+        }
+    }
+
+    /**
+     * Gets a cached associated with the connection id if available.
+     *
+     */
+    public static AccessibilityCache getCache(int connectionId) {
+        synchronized (sConnectionCache) {
+            return sCaches.get(connectionId);
         }
     }
 
@@ -255,6 +253,7 @@
     public static void removeConnection(int connectionId) {
         synchronized (sConnectionCache) {
             sConnectionCache.remove(connectionId);
+            sCaches.remove(connectionId);
         }
     }
 
@@ -263,32 +262,21 @@
      * tests need to be able to verify this class's interactions with the cache
      */
     @VisibleForTesting
-    public static void setCache(AccessibilityCache cache) {
-        sAccessibilityCache = cache;
+    public static void setCache(int connectionId, AccessibilityCache cache) {
+        synchronized (sConnectionCache) {
+            sCaches.put(connectionId, cache);
+        }
     }
 
     private AccessibilityInteractionClient() {
         /* reducing constructor visibility */
-        this(true);
-    }
-
-    private AccessibilityInteractionClient(boolean initializeCache) {
-        initializeCache(initializeCache);
         mAccessibilityManager = null;
     }
 
-    private AccessibilityInteractionClient(boolean initializeCache, Context context) {
-        initializeCache(initializeCache);
+    private AccessibilityInteractionClient(Context context) {
         mAccessibilityManager = context.getSystemService(AccessibilityManager.class);
     }
 
-    private static void initializeCache(boolean initialize) {
-        if (initialize && sAccessibilityCache == null) {
-            sAccessibilityCache = new AccessibilityCache(
-                    new AccessibilityCache.AccessibilityNodeRefresher());
-        }
-    }
-
     /**
      * Sets the message to be processed if the interacted view hierarchy
      * and the interacting client are running in the same thread.
@@ -333,7 +321,7 @@
      *
      * @param connectionId The id of a connection for interacting with the system.
      * @param accessibilityWindowId A unique window id. Use
-     *     {@link android.view.accessibility.AccessibilityWindowInfo#ACTIVE_WINDOW_ID}
+     *     {@link AccessibilityWindowInfo#ACTIVE_WINDOW_ID}
      *     to query the currently active window.
      * @param bypassCache Whether to bypass the cache.
      * @return The {@link AccessibilityWindowInfo}.
@@ -344,21 +332,28 @@
             IAccessibilityServiceConnection connection = getConnection(connectionId);
             if (connection != null) {
                 AccessibilityWindowInfo window;
-                if (!bypassCache && sAccessibilityCache != null) {
-                    window = sAccessibilityCache.getWindow(accessibilityWindowId);
-                    if (window != null) {
+                AccessibilityCache cache = getCache(connectionId);
+                if (cache != null) {
+                    if (!bypassCache) {
+                        window = cache.getWindow(accessibilityWindowId);
+                        if (window != null) {
+                            if (DEBUG) {
+                                Log.i(LOG_TAG, "Window cache hit");
+                            }
+                            if (shouldTraceClient()) {
+                                logTraceClient(connection, "getWindow cache",
+                                        "connectionId=" + connectionId + ";accessibilityWindowId="
+                                                + accessibilityWindowId + ";bypassCache=false");
+                            }
+                            return window;
+                        }
                         if (DEBUG) {
-                            Log.i(LOG_TAG, "Window cache hit");
+                            Log.i(LOG_TAG, "Window cache miss");
                         }
-                        if (shouldTraceClient()) {
-                            logTraceClient(connection, "getWindow cache",
-                                    "connectionId=" + connectionId + ";accessibilityWindowId="
-                                    + accessibilityWindowId + ";bypassCache=false");
-                        }
-                        return window;
                     }
+                } else {
                     if (DEBUG) {
-                        Log.i(LOG_TAG, "Window cache miss");
+                        Log.w(LOG_TAG, "Cache is null for connection id: " + connectionId);
                     }
                 }
 
@@ -374,9 +369,9 @@
                             + bypassCache);
                 }
 
-                if (window != null && sAccessibilityCache != null) {
-                    if (!bypassCache) {
-                        sAccessibilityCache.addWindow(window);
+                if (window != null) {
+                    if (!bypassCache && cache != null) {
+                        cache.addWindow(window);
                     }
                     return window;
                 }
@@ -418,8 +413,9 @@
             IAccessibilityServiceConnection connection = getConnection(connectionId);
             if (connection != null) {
                 SparseArray<List<AccessibilityWindowInfo>> windows;
-                if (sAccessibilityCache != null) {
-                    windows = sAccessibilityCache.getWindowsOnAllDisplays();
+                AccessibilityCache cache = getCache(connectionId);
+                if (cache != null) {
+                    windows = cache.getWindowsOnAllDisplays();
                     if (windows != null) {
                         if (DEBUG) {
                             Log.i(LOG_TAG, "Windows cache hit");
@@ -433,6 +429,10 @@
                     if (DEBUG) {
                         Log.i(LOG_TAG, "Windows cache miss");
                     }
+                } else {
+                    if (DEBUG) {
+                        Log.w(LOG_TAG, "Cache is null for connection id: " + connectionId);
+                    }
                 }
 
                 long populationTimeStamp;
@@ -447,8 +447,8 @@
                     logTraceClient(connection, "getWindows", "connectionId=" + connectionId);
                 }
                 if (windows != null) {
-                    if (sAccessibilityCache != null) {
-                        sAccessibilityCache.setWindowsOnAllDisplays(windows, populationTimeStamp);
+                    if (cache != null) {
+                        cache.setWindowsOnAllDisplays(windows, populationTimeStamp);
                     }
                     return windows;
                 }
@@ -533,28 +533,35 @@
         try {
             IAccessibilityServiceConnection connection = getConnection(connectionId);
             if (connection != null) {
-                if (!bypassCache && sAccessibilityCache != null) {
-                    AccessibilityNodeInfo cachedInfo = sAccessibilityCache.getNode(
-                            accessibilityWindowId, accessibilityNodeId);
-                    if (cachedInfo != null) {
+                if (!bypassCache) {
+                    AccessibilityCache cache = getCache(connectionId);
+                    if (cache != null) {
+                        AccessibilityNodeInfo cachedInfo = cache.getNode(
+                                accessibilityWindowId, accessibilityNodeId);
+                        if (cachedInfo != null) {
+                            if (DEBUG) {
+                                Log.i(LOG_TAG, "Node cache hit for "
+                                        + idToString(accessibilityWindowId, accessibilityNodeId));
+                            }
+                            if (shouldTraceClient()) {
+                                logTraceClient(connection,
+                                        "findAccessibilityNodeInfoByAccessibilityId cache",
+                                        "connectionId=" + connectionId + ";accessibilityWindowId="
+                                                + accessibilityWindowId + ";accessibilityNodeId="
+                                                + accessibilityNodeId + ";bypassCache="
+                                                + bypassCache + ";prefetchFlags=" + prefetchFlags
+                                                + ";arguments=" + arguments);
+                            }
+                            return cachedInfo;
+                        }
                         if (DEBUG) {
-                            Log.i(LOG_TAG, "Node cache hit for "
+                            Log.i(LOG_TAG, "Node cache miss for "
                                     + idToString(accessibilityWindowId, accessibilityNodeId));
                         }
-                        if (shouldTraceClient()) {
-                            logTraceClient(connection,
-                                    "findAccessibilityNodeInfoByAccessibilityId cache",
-                                    "connectionId=" + connectionId + ";accessibilityWindowId="
-                                    + accessibilityWindowId + ";accessibilityNodeId="
-                                    + accessibilityNodeId + ";bypassCache=" + bypassCache
-                                    + ";prefetchFlags=" + prefetchFlags + ";arguments="
-                                    + arguments);
+                    } else {
+                        if (DEBUG) {
+                            Log.w(LOG_TAG, "Cache is null for connection id: " + connectionId);
                         }
-                        return cachedInfo;
-                    }
-                    if (DEBUG) {
-                        Log.i(LOG_TAG, "Node cache miss for "
-                                + idToString(accessibilityWindowId, accessibilityNodeId));
                     }
                 } else {
                     // No need to prefech nodes in bypass cache case.
@@ -758,19 +765,19 @@
     }
 
     /**
-     * Finds the {@link android.view.accessibility.AccessibilityNodeInfo} that has the
+     * Finds the {@link AccessibilityNodeInfo} that has the
      * specified focus type. The search is performed in the window whose id is specified
      * and starts from the node whose accessibility id is specified.
      *
      * @param connectionId The id of a connection for interacting with the system.
      * @param accessibilityWindowId A unique window id. Use
-     *     {@link android.view.accessibility.AccessibilityWindowInfo#ACTIVE_WINDOW_ID}
+     *     {@link AccessibilityWindowInfo#ACTIVE_WINDOW_ID}
      *     to query the currently active window. Use
-     *     {@link android.view.accessibility.AccessibilityWindowInfo#ANY_WINDOW_ID} to query all
+     *     {@link AccessibilityWindowInfo#ANY_WINDOW_ID} to query all
      *     windows
      * @param accessibilityNodeId A unique view id or virtual descendant id from
      *     where to start the search. Use
-     *     {@link android.view.accessibility.AccessibilityNodeInfo#ROOT_NODE_ID}
+     *     {@link AccessibilityNodeInfo#ROOT_NODE_ID}
      *     to start from the root.
      * @param focusType The focus type.
      * @return The accessibility focused {@link AccessibilityNodeInfo}.
@@ -781,8 +788,9 @@
         try {
             IAccessibilityServiceConnection connection = getConnection(connectionId);
             if (connection != null) {
-                if (sAccessibilityCache != null) {
-                    AccessibilityNodeInfo cachedInfo = sAccessibilityCache.getFocus(focusType,
+                AccessibilityCache cache = getCache(connectionId);
+                if (cache != null) {
+                    AccessibilityNodeInfo cachedInfo = cache.getFocus(focusType,
                             accessibilityNodeId, accessibilityWindowId);
                     if (cachedInfo != null) {
                         if (DEBUG) {
@@ -796,6 +804,10 @@
                         Log.i(LOG_TAG, "Focused node cache miss with "
                                 + idToString(accessibilityWindowId, accessibilityNodeId));
                     }
+                } else {
+                    if (DEBUG) {
+                        Log.w(LOG_TAG, "Cache is null for connection id: " + connectionId);
+                    }
                 }
                 final int interactionId = mInteractionIdCounter.getAndIncrement();
                 if (shouldTraceClient()) {
@@ -956,16 +968,25 @@
     }
 
     /**
-     * Clears the accessibility cache.
+     * Clears the cache associated with {@code connectionId}
+     * @param connectionId the connection id
+     * TODO(207417185): Modify UnsupportedAppUsage
      */
     @UnsupportedAppUsage()
-    public void clearCache() {
-        if (sAccessibilityCache != null) {
-            sAccessibilityCache.clear();
+    public void clearCache(int connectionId) {
+        AccessibilityCache cache = getCache(connectionId);
+        if (cache == null) {
+            return;
         }
+        cache.clear();
     }
 
-    public void onAccessibilityEvent(AccessibilityEvent event) {
+    /**
+     * Informs the cache associated with {@code connectionId} of {@code event}
+     * @param event the event
+     * @param connectionId the connection id
+     */
+    public void onAccessibilityEvent(AccessibilityEvent event, int connectionId) {
         switch (event.getEventType()) {
             case AccessibilityEvent.TYPE_VIEW_SCROLLED:
                 updateScrollingWindow(event.getWindowId(), SystemClock.uptimeMillis());
@@ -978,9 +999,14 @@
             default:
                 break;
         }
-        if (sAccessibilityCache != null) {
-            sAccessibilityCache.onAccessibilityEvent(event);
+        AccessibilityCache cache = getCache(connectionId);
+        if (cache == null) {
+            if (DEBUG) {
+                Log.w(LOG_TAG, "Cache is null for connection id: " + connectionId);
+            }
+            return;
         }
+        cache.onAccessibilityEvent(event);
     }
 
     /**
@@ -1216,8 +1242,15 @@
                 }
             }
             info.setSealed(true);
-            if (!bypassCache && sAccessibilityCache != null) {
-                sAccessibilityCache.add(info);
+            if (!bypassCache) {
+                AccessibilityCache cache = getCache(connectionId);
+                if (cache == null) {
+                    if (DEBUG) {
+                        Log.w(LOG_TAG, "Cache is null for connection id: " + connectionId);
+                    }
+                    return;
+                }
+                cache.add(info);
             }
         }
     }
diff --git a/core/tests/coretests/src/android/view/accessibility/AccessibilityCacheTest.java b/core/tests/coretests/src/android/view/accessibility/AccessibilityCacheTest.java
index 33c6a4b..e689b5d3 100644
--- a/core/tests/coretests/src/android/view/accessibility/AccessibilityCacheTest.java
+++ b/core/tests/coretests/src/android/view/accessibility/AccessibilityCacheTest.java
@@ -45,7 +45,6 @@
 
 import com.google.common.base.Throwables;
 
-import org.junit.After;
 import org.junit.Before;
 import org.junit.Test;
 import org.junit.runner.RunWith;
@@ -82,13 +81,6 @@
         mAccessibilityCache = new AccessibilityCache(mAccessibilityNodeRefresher);
     }
 
-    @After
-    public void tearDown() {
-        // Make sure we're recycling all of our window and node infos.
-        mAccessibilityCache.clear();
-        AccessibilityInteractionClient.getInstance().clearCache();
-    }
-
     @Test
     public void testEmptyCache_returnsNull() {
         assertNull(mAccessibilityCache.getNode(0, 0));
diff --git a/core/tests/coretests/src/android/view/accessibility/AccessibilityInteractionClientTest.java b/core/tests/coretests/src/android/view/accessibility/AccessibilityInteractionClientTest.java
index 7e1e7f4..3e061d2 100644
--- a/core/tests/coretests/src/android/view/accessibility/AccessibilityInteractionClientTest.java
+++ b/core/tests/coretests/src/android/view/accessibility/AccessibilityInteractionClientTest.java
@@ -17,6 +17,9 @@
 package android.view.accessibility;
 
 import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotEquals;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertNull;
 import static org.mockito.Mockito.verifyZeroInteractions;
 import static org.mockito.MockitoAnnotations.initMocks;
 
@@ -40,6 +43,8 @@
 @RunWith(AndroidJUnit4.class)
 public class AccessibilityInteractionClientTest {
     private static final int MOCK_CONNECTION_ID = 0xabcd;
+    private static final int MOCK_CONNECTION_OTHER_ID = 0xabce;
+
 
     private MockConnection mMockConnection = new MockConnection();
     @Mock private AccessibilityCache mMockCache;
@@ -47,8 +52,8 @@
     @Before
     public void setUp() {
         initMocks(this);
-        AccessibilityInteractionClient.setCache(mMockCache);
-        AccessibilityInteractionClient.addConnection(MOCK_CONNECTION_ID, mMockConnection);
+        AccessibilityInteractionClient.addConnection(
+                MOCK_CONNECTION_ID, mMockConnection, /*initializeCache=*/true);
     }
 
     /**
@@ -58,6 +63,7 @@
      */
     @Test
     public void findA11yNodeInfoByA11yId_whenBypassingCache_doesntTouchCache() {
+        AccessibilityInteractionClient.setCache(MOCK_CONNECTION_ID, mMockCache);
         final int windowId = 0x1234;
         final long accessibilityNodeId = 0x4321L;
         AccessibilityNodeInfo nodeFromConnection = AccessibilityNodeInfo.obtain();
@@ -71,6 +77,42 @@
         verifyZeroInteractions(mMockCache);
     }
 
+    @Test
+    public void getCache_differentConnections_returnsDifferentCaches() {
+        MockConnection mOtherMockConnection = new MockConnection();
+        AccessibilityInteractionClient.addConnection(
+                MOCK_CONNECTION_OTHER_ID, mOtherMockConnection, /*initializeCache=*/true);
+
+        AccessibilityCache firstCache = AccessibilityInteractionClient.getCache(MOCK_CONNECTION_ID);
+        AccessibilityCache secondCache = AccessibilityInteractionClient.getCache(
+                MOCK_CONNECTION_OTHER_ID);
+        assertNotEquals(firstCache, secondCache);
+    }
+
+    @Test
+    public void getCache_addConnectionWithoutCache_returnsNullCache() {
+        // Need to first remove from process cache
+        AccessibilityInteractionClient.removeConnection(MOCK_CONNECTION_OTHER_ID);
+
+        MockConnection mOtherMockConnection = new MockConnection();
+        AccessibilityInteractionClient.addConnection(
+                MOCK_CONNECTION_OTHER_ID, mOtherMockConnection, /*initializeCache=*/false);
+
+        AccessibilityCache cache = AccessibilityInteractionClient.getCache(
+                MOCK_CONNECTION_OTHER_ID);
+        assertNull(cache);
+    }
+
+    @Test
+    public void getCache_removeConnection_returnsNull() {
+        AccessibilityCache cache = AccessibilityInteractionClient.getCache(MOCK_CONNECTION_ID);
+        assertNotNull(cache);
+
+        AccessibilityInteractionClient.removeConnection(MOCK_CONNECTION_ID);
+        cache = AccessibilityInteractionClient.getCache(MOCK_CONNECTION_ID);
+        assertNull(cache);
+    }
+
     private static class MockConnection extends AccessibilityServiceConnectionImpl {
         AccessibilityNodeInfo mInfoToReturn;
 
diff --git a/services/accessibility/java/com/android/server/accessibility/AccessibilityManagerService.java b/services/accessibility/java/com/android/server/accessibility/AccessibilityManagerService.java
index a11e264..d0b298f 100644
--- a/services/accessibility/java/com/android/server/accessibility/AccessibilityManagerService.java
+++ b/services/accessibility/java/com/android/server/accessibility/AccessibilityManagerService.java
@@ -3528,9 +3528,8 @@
 
             mConnectionId = service.mId;
 
-            mClient = AccessibilityInteractionClient.getInstance(/* initializeCache= */false,
-                    mContext);
-            mClient.addConnection(mConnectionId, service);
+            mClient = AccessibilityInteractionClient.getInstance(mContext);
+            mClient.addConnection(mConnectionId, service, /*initializeCache=*/false);
 
             //TODO: (multi-display) We need to support multiple displays.
             DisplayManager displayManager = (DisplayManager)