Merge "Refactor A11yInteractionClient so every service has its own cache"
diff --git a/core/java/android/accessibilityservice/AccessibilityService.java b/core/java/android/accessibilityservice/AccessibilityService.java
index 0f852b4..09af72d 100644
--- a/core/java/android/accessibilityservice/AccessibilityService.java
+++ b/core/java/android/accessibilityservice/AccessibilityService.java
@@ -2071,7 +2071,7 @@
             try {
                 connection.setServiceInfo(mInfo);
                 mInfo = null;
-                AccessibilityInteractionClient.getInstance(this).clearCache();
+                AccessibilityInteractionClient.getInstance(this).clearCache(mConnectionId);
             } catch (RemoteException re) {
                 Log.w(LOG_TAG, "Error while setting AccessibilityServiceInfo", re);
                 re.rethrowFromSystemServer();
@@ -2421,7 +2421,7 @@
                     if (event != null) {
                         // Send the event to AccessibilityCache via AccessibilityInteractionClient
                         AccessibilityInteractionClient.getInstance(mContext).onAccessibilityEvent(
-                                event);
+                                event, mConnectionId);
                         if (serviceWantsEvent
                                 && (mConnectionId != AccessibilityInteractionClient.NO_ID)) {
                             // Send the event to AccessibilityService
@@ -2451,7 +2451,7 @@
                     args.recycle();
                     if (connection != null) {
                         AccessibilityInteractionClient.getInstance(mContext).addConnection(
-                                mConnectionId, connection);
+                                mConnectionId, connection, /*initializeCache=*/true);
                         if (mContext != null) {
                             try {
                                 connection.setAttributionTag(mContext.getAttributionTag());
@@ -2466,7 +2466,10 @@
-                        AccessibilityInteractionClient.getInstance(mContext).removeConnection(
-                                mConnectionId);
-                        mConnectionId = AccessibilityInteractionClient.NO_ID;
-                        AccessibilityInteractionClient.getInstance(mContext).clearCache();
+                        // Clear the cache while mConnectionId still names this connection:
+                        // removeConnection() drops the cache and the id is then reset to NO_ID.
+                        AccessibilityInteractionClient.getInstance(mContext)
+                                .clearCache(mConnectionId);
+                        AccessibilityInteractionClient.getInstance(mContext).removeConnection(
+                                mConnectionId);
+                        mConnectionId = AccessibilityInteractionClient.NO_ID;
                         mCallback.init(AccessibilityInteractionClient.NO_ID, null);
                     }
                     return;
@@ -2478,7 +2481,7 @@
                     return;
                 }
                 case DO_CLEAR_ACCESSIBILITY_CACHE: {
-                    AccessibilityInteractionClient.getInstance(mContext).clearCache();
+                    AccessibilityInteractionClient.getInstance(mContext).clearCache(mConnectionId);
                     return;
                 }
                 case DO_ON_KEY_EVENT: {
diff --git a/core/java/android/app/UiAutomation.java b/core/java/android/app/UiAutomation.java
index 828b171..58ded71 100644
--- a/core/java/android/app/UiAutomation.java
+++ b/core/java/android/app/UiAutomation.java
@@ -638,7 +638,9 @@
         final IAccessibilityServiceConnection connection;
         synchronized (mLock) {
             throwIfNotConnectedLocked();
-            AccessibilityInteractionClient.getInstance().clearCache();
-            connection = AccessibilityInteractionClient.getInstance()
-                    .getConnection(mConnectionId);
+            // Only this connection's cache is cleared; caches of other connections are kept.
+            final AccessibilityInteractionClient client =
+                    AccessibilityInteractionClient.getInstance();
+            client.clearCache(mConnectionId);
+            connection = client.getConnection(mConnectionId);
         }
diff --git a/core/java/android/view/AccessibilityInteractionController.java b/core/java/android/view/AccessibilityInteractionController.java
index 9cd8313..de56d3a 100644
--- a/core/java/android/view/AccessibilityInteractionController.java
+++ b/core/java/android/view/AccessibilityInteractionController.java
@@ -151,8 +151,7 @@
             if (interrogatingPid == mMyProcessId && interrogatingTid == mMyLooperThreadId
                     && mHandler.hasAccessibilityCallback(message)) {
-                AccessibilityInteractionClient.getInstanceForThread(
-                        interrogatingTid, /* initializeCache= */true)
-                        .setSameThreadMessage(message);
+                AccessibilityInteractionClient.getInstanceForThread(interrogatingTid)
+                        .setSameThreadMessage(message);
             } else {
                 // For messages without callback of interrogating client, just handle the
                 // message immediately if this is UI thread.
diff --git a/core/java/android/view/accessibility/AccessibilityInteractionClient.java b/core/java/android/view/accessibility/AccessibilityInteractionClient.java
index bc21488..dc4c59a 100644
--- a/core/java/android/view/accessibility/AccessibilityInteractionClient.java
+++ b/core/java/android/view/accessibility/AccessibilityInteractionClient.java
@@ -115,13 +115,14 @@
         from a window, mapping from windowId -> timestamp. */
     private static final SparseLongArray sScrollingWindows = new SparseLongArray();
 
-    private static AccessibilityCache sAccessibilityCache;
+    /** Per-connection caches, keyed by connection id. Guarded by {@code sConnectionCache}. */
+    private static final SparseArray<AccessibilityCache> sCaches = new SparseArray<>();
 
     private final AtomicInteger mInteractionIdCounter = new AtomicInteger();
 
     private final Object mInstanceLock = new Object();
 
-    private final AccessibilityManager  mAccessibilityManager;
+    private final AccessibilityManager mAccessibilityManager;
 
     private volatile int mInteractionId = -1;
     private volatile int mCallingUid = Process.INVALID_UID;
@@ -150,7 +151,37 @@
     @UnsupportedAppUsage()
     public static AccessibilityInteractionClient getInstance() {
         final long threadId = Thread.currentThread().getId();
-        return getInstanceForThread(threadId, true);
+        return getInstanceForThread(threadId);
+    }
+
+    /**
+     * <strong>Note:</strong> We keep one instance per interrogating thread since
+     * the instance contains state which can lead to undesired thread interleavings.
+     * We do not have a thread local variable since other threads should be able to
+     * look up the correct client knowing a thread id. See ViewRootImpl for details.
+     *
+     * @return The client for a given <code>threadId</code>.
+     */
+    public static AccessibilityInteractionClient getInstanceForThread(long threadId) {
+        synchronized (sStaticLock) {
+            AccessibilityInteractionClient client = sClients.get(threadId);
+            if (client == null) {
+                client = new AccessibilityInteractionClient();
+                sClients.put(threadId, client);
+            }
+            return client;
+        }
+    }
+
+    /**
+     * @return The client for the current thread.
+     */
+    public static AccessibilityInteractionClient getInstance(Context context) {
+        final long threadId = Thread.currentThread().getId();
+        if (context != null) {
+            return getInstanceForThread(threadId, context);
+        }
+        return getInstanceForThread(threadId);
     }
 
     /**
@@ -162,61 +193,11 @@
      * @return The client for a given <code>threadId</code>.
      */
     public static AccessibilityInteractionClient getInstanceForThread(long threadId,
-            boolean initializeCache) {
-        synchronized (sStaticLock) {
-            AccessibilityInteractionClient client = sClients.get(threadId);
-            if (client == null) {
-                if (Binder.getCallingUid() == Process.SYSTEM_UID) {
-                    // Don't initialize a cache for the system process
-                    client = new AccessibilityInteractionClient(false);
-                } else {
-                    client = new AccessibilityInteractionClient(initializeCache);
-                }
-                sClients.put(threadId, client);
-            }
-            return client;
-        }
-    }
-
-    /**
-     * @return The client for the current thread.
-     */
-    public static AccessibilityInteractionClient getInstance(Context context) {
-        return getInstance(/* initializeCache= */true, context);
-    }
-
-    /**
-     * @param initializeCache whether to initialize the cache in a new client instance
-     * @return The client for the current thread.
-     */
-    public static AccessibilityInteractionClient getInstance(boolean initializeCache,
             Context context) {
-        final long threadId = Thread.currentThread().getId();
-        if (context != null) {
-            return getInstanceForThread(threadId, initializeCache, context);
-        }
-        return getInstanceForThread(threadId, initializeCache);
-    }
-
-    /**
-     * <strong>Note:</strong> We keep one instance per interrogating thread since
-     * the instance contains state which can lead to undesired thread interleavings.
-     * We do not have a thread local variable since other threads should be able to
-     * look up the correct client knowing a thread id. See ViewRootImpl for details.
-     *
-     * @param initializeCache whether to initialize the cache in a new client instance
-     * @return The client for a given <code>threadId</code>.
-     */
-    public static AccessibilityInteractionClient getInstanceForThread(
-            long threadId, boolean initializeCache, Context context) {
         synchronized (sStaticLock) {
             AccessibilityInteractionClient client = sClients.get(threadId);
             if (client == null) {
-                if (Binder.getCallingUid() == Process.SYSTEM_UID) {
-                    client = new AccessibilityInteractionClient(false, context);
-                } else {
-                    client = new AccessibilityInteractionClient(initializeCache, context);
-                }
+                client = new AccessibilityInteractionClient(context);
                 sClients.put(threadId, client);
             }
             return client;
@@ -238,12 +219,36 @@
     /**
      * Adds a cached accessibility service connection.
      *
+     * <p>If {@code initializeCache} is true, a new cache is created for the connection;
+     * otherwise any cache previously registered under {@code connectionId} is discarded, so a
+     * connection added without a cache never reads a stale one.
+     *
      * @param connectionId The connection id.
      * @param connection The connection.
+     * @param initializeCache Whether to create a cache for this connection.
      */
-    public static void addConnection(int connectionId, IAccessibilityServiceConnection connection) {
+    public static void addConnection(int connectionId, IAccessibilityServiceConnection connection,
+            boolean initializeCache) {
         synchronized (sConnectionCache) {
             sConnectionCache.put(connectionId, connection);
+            if (initializeCache) {
+                sCaches.put(connectionId, new AccessibilityCache(
+                        new AccessibilityCache.AccessibilityNodeRefresher()));
+            } else {
+                sCaches.remove(connectionId);
+            }
+        }
+    }
+
+    /**
+     * Gets the cache associated with the connection id, if any.
+     *
+     * @param connectionId The connection id.
+     * @return The cache, or {@code null} if the connection has no cache.
+     */
+    public static AccessibilityCache getCache(int connectionId) {
+        synchronized (sConnectionCache) {
+            return sCaches.get(connectionId);
         }
     }
 
@@ -255,6 +260,7 @@
     public static void removeConnection(int connectionId) {
         synchronized (sConnectionCache) {
             sConnectionCache.remove(connectionId);
+            sCaches.remove(connectionId);
         }
     }
 
@@ -263,32 +262,21 @@
      * tests need to be able to verify this class's interactions with the cache
      */
     @VisibleForTesting
-    public static void setCache(AccessibilityCache cache) {
-        sAccessibilityCache = cache;
+    public static void setCache(int connectionId, AccessibilityCache cache) {
+        synchronized (sConnectionCache) {
+            sCaches.put(connectionId, cache);
+        }
     }
 
     private AccessibilityInteractionClient() {
         /* reducing constructor visibility */
-        this(true);
-    }
-
-    private AccessibilityInteractionClient(boolean initializeCache) {
-        initializeCache(initializeCache);
         mAccessibilityManager = null;
     }
 
-    private AccessibilityInteractionClient(boolean initializeCache, Context context) {
-        initializeCache(initializeCache);
+    private AccessibilityInteractionClient(Context context) {
         mAccessibilityManager = context.getSystemService(AccessibilityManager.class);
     }
 
-    private static void initializeCache(boolean initialize) {
-        if (initialize && sAccessibilityCache == null) {
-            sAccessibilityCache = new AccessibilityCache(
-                    new AccessibilityCache.AccessibilityNodeRefresher());
-        }
-    }
-
     /**
      * Sets the message to be processed if the interacted view hierarchy
      * and the interacting client are running in the same thread.
@@ -333,7 +321,7 @@
      *
      * @param connectionId The id of a connection for interacting with the system.
      * @param accessibilityWindowId A unique window id. Use
-     *     {@link android.view.accessibility.AccessibilityWindowInfo#ACTIVE_WINDOW_ID}
+     *     {@link AccessibilityWindowInfo#ACTIVE_WINDOW_ID}
      *     to query the currently active window.
      * @param bypassCache Whether to bypass the cache.
      * @return The {@link AccessibilityWindowInfo}.
@@ -344,21 +332,28 @@
             IAccessibilityServiceConnection connection = getConnection(connectionId);
             if (connection != null) {
                 AccessibilityWindowInfo window;
-                if (!bypassCache && sAccessibilityCache != null) {
-                    window = sAccessibilityCache.getWindow(accessibilityWindowId);
-                    if (window != null) {
+                AccessibilityCache cache = getCache(connectionId);
+                if (cache != null) {
+                    if (!bypassCache) {
+                        window = cache.getWindow(accessibilityWindowId);
+                        if (window != null) {
+                            if (DEBUG) {
+                                Log.i(LOG_TAG, "Window cache hit");
+                            }
+                            if (shouldTraceClient()) {
+                                logTraceClient(connection, "getWindow cache",
+                                        "connectionId=" + connectionId + ";accessibilityWindowId="
+                                                + accessibilityWindowId + ";bypassCache=false");
+                            }
+                            return window;
+                        }
                         if (DEBUG) {
-                            Log.i(LOG_TAG, "Window cache hit");
+                            Log.i(LOG_TAG, "Window cache miss");
                         }
-                        if (shouldTraceClient()) {
-                            logTraceClient(connection, "getWindow cache",
-                                    "connectionId=" + connectionId + ";accessibilityWindowId="
-                                    + accessibilityWindowId + ";bypassCache=false");
-                        }
-                        return window;
                     }
+                } else {
                     if (DEBUG) {
-                        Log.i(LOG_TAG, "Window cache miss");
+                        Log.w(LOG_TAG, "Cache is null for connection id: " + connectionId);
                     }
                 }
 
@@ -374,9 +369,9 @@
                             + bypassCache);
                 }
 
-                if (window != null && sAccessibilityCache != null) {
-                    if (!bypassCache) {
-                        sAccessibilityCache.addWindow(window);
+                if (window != null) {
+                    if (!bypassCache && cache != null) {
+                        cache.addWindow(window);
                     }
                     return window;
                 }
@@ -418,8 +413,9 @@
             IAccessibilityServiceConnection connection = getConnection(connectionId);
             if (connection != null) {
                 SparseArray<List<AccessibilityWindowInfo>> windows;
-                if (sAccessibilityCache != null) {
-                    windows = sAccessibilityCache.getWindowsOnAllDisplays();
+                AccessibilityCache cache = getCache(connectionId);
+                if (cache != null) {
+                    windows = cache.getWindowsOnAllDisplays();
                     if (windows != null) {
                         if (DEBUG) {
                             Log.i(LOG_TAG, "Windows cache hit");
@@ -433,6 +429,10 @@
                     if (DEBUG) {
                         Log.i(LOG_TAG, "Windows cache miss");
                     }
+                } else {
+                    if (DEBUG) {
+                        Log.w(LOG_TAG, "Cache is null for connection id: " + connectionId);
+                    }
                 }
 
                 long populationTimeStamp;
@@ -447,8 +447,8 @@
                     logTraceClient(connection, "getWindows", "connectionId=" + connectionId);
                 }
                 if (windows != null) {
-                    if (sAccessibilityCache != null) {
-                        sAccessibilityCache.setWindowsOnAllDisplays(windows, populationTimeStamp);
+                    if (cache != null) {
+                        cache.setWindowsOnAllDisplays(windows, populationTimeStamp);
                     }
                     return windows;
                 }
@@ -533,28 +533,35 @@
         try {
             IAccessibilityServiceConnection connection = getConnection(connectionId);
             if (connection != null) {
-                if (!bypassCache && sAccessibilityCache != null) {
-                    AccessibilityNodeInfo cachedInfo = sAccessibilityCache.getNode(
-                            accessibilityWindowId, accessibilityNodeId);
-                    if (cachedInfo != null) {
+                if (!bypassCache) {
+                    AccessibilityCache cache = getCache(connectionId);
+                    if (cache != null) {
+                        AccessibilityNodeInfo cachedInfo = cache.getNode(
+                                accessibilityWindowId, accessibilityNodeId);
+                        if (cachedInfo != null) {
+                            if (DEBUG) {
+                                Log.i(LOG_TAG, "Node cache hit for "
+                                        + idToString(accessibilityWindowId, accessibilityNodeId));
+                            }
+                            if (shouldTraceClient()) {
+                                logTraceClient(connection,
+                                        "findAccessibilityNodeInfoByAccessibilityId cache",
+                                        "connectionId=" + connectionId + ";accessibilityWindowId="
+                                                + accessibilityWindowId + ";accessibilityNodeId="
+                                                + accessibilityNodeId + ";bypassCache="
+                                                + bypassCache + ";prefetchFlags=" + prefetchFlags
+                                                + ";arguments=" + arguments);
+                            }
+                            return cachedInfo;
+                        }
                         if (DEBUG) {
-                            Log.i(LOG_TAG, "Node cache hit for "
+                            Log.i(LOG_TAG, "Node cache miss for "
                                     + idToString(accessibilityWindowId, accessibilityNodeId));
                         }
-                        if (shouldTraceClient()) {
-                            logTraceClient(connection,
-                                    "findAccessibilityNodeInfoByAccessibilityId cache",
-                                    "connectionId=" + connectionId + ";accessibilityWindowId="
-                                    + accessibilityWindowId + ";accessibilityNodeId="
-                                    + accessibilityNodeId + ";bypassCache=" + bypassCache
-                                    + ";prefetchFlags=" + prefetchFlags + ";arguments="
-                                    + arguments);
+                    } else {
+                        if (DEBUG) {
+                            Log.w(LOG_TAG, "Cache is null for connection id: " + connectionId);
                         }
-                        return cachedInfo;
-                    }
-                    if (DEBUG) {
-                        Log.i(LOG_TAG, "Node cache miss for "
-                                + idToString(accessibilityWindowId, accessibilityNodeId));
                     }
                 } else {
                     // No need to prefech nodes in bypass cache case.
@@ -758,19 +765,19 @@
     }
 
     /**
-     * Finds the {@link android.view.accessibility.AccessibilityNodeInfo} that has the
+     * Finds the {@link AccessibilityNodeInfo} that has the
      * specified focus type. The search is performed in the window whose id is specified
      * and starts from the node whose accessibility id is specified.
      *
      * @param connectionId The id of a connection for interacting with the system.
      * @param accessibilityWindowId A unique window id. Use
-     *     {@link android.view.accessibility.AccessibilityWindowInfo#ACTIVE_WINDOW_ID}
+     *     {@link AccessibilityWindowInfo#ACTIVE_WINDOW_ID}
      *     to query the currently active window. Use
-     *     {@link android.view.accessibility.AccessibilityWindowInfo#ANY_WINDOW_ID} to query all
+     *     {@link AccessibilityWindowInfo#ANY_WINDOW_ID} to query all
      *     windows
      * @param accessibilityNodeId A unique view id or virtual descendant id from
      *     where to start the search. Use
-     *     {@link android.view.accessibility.AccessibilityNodeInfo#ROOT_NODE_ID}
+     *     {@link AccessibilityNodeInfo#ROOT_NODE_ID}
      *     to start from the root.
      * @param focusType The focus type.
      * @return The accessibility focused {@link AccessibilityNodeInfo}.
@@ -781,8 +788,9 @@
         try {
             IAccessibilityServiceConnection connection = getConnection(connectionId);
             if (connection != null) {
-                if (sAccessibilityCache != null) {
-                    AccessibilityNodeInfo cachedInfo = sAccessibilityCache.getFocus(focusType,
+                AccessibilityCache cache = getCache(connectionId);
+                if (cache != null) {
+                    AccessibilityNodeInfo cachedInfo = cache.getFocus(focusType,
                             accessibilityNodeId, accessibilityWindowId);
                     if (cachedInfo != null) {
                         if (DEBUG) {
@@ -796,6 +804,10 @@
                         Log.i(LOG_TAG, "Focused node cache miss with "
                                 + idToString(accessibilityWindowId, accessibilityNodeId));
                     }
+                } else {
+                    if (DEBUG) {
+                        Log.w(LOG_TAG, "Cache is null for connection id: " + connectionId);
+                    }
                 }
                 final int interactionId = mInteractionIdCounter.getAndIncrement();
                 if (shouldTraceClient()) {
@@ -956,16 +968,25 @@
     }
 
     /**
-     * Clears the accessibility cache.
+     * Clears the cache associated with {@code connectionId}, if there is one.
+     * TODO(207417185): Modify UnsupportedAppUsage
+     * @param connectionId the connection id
      */
     @UnsupportedAppUsage()
-    public void clearCache() {
-        if (sAccessibilityCache != null) {
-            sAccessibilityCache.clear();
+    public void clearCache(int connectionId) {
+        AccessibilityCache cache = getCache(connectionId);
+        if (cache == null) {
+            return;
         }
+        cache.clear();
     }
 
-    public void onAccessibilityEvent(AccessibilityEvent event) {
+    /**
+     * Informs the cache associated with {@code connectionId}, if there is one, of {@code event}.
+     * @param event the event
+     * @param connectionId the connection id
+     */
+    public void onAccessibilityEvent(AccessibilityEvent event, int connectionId) {
         switch (event.getEventType()) {
             case AccessibilityEvent.TYPE_VIEW_SCROLLED:
                 updateScrollingWindow(event.getWindowId(), SystemClock.uptimeMillis());
@@ -978,9 +999,14 @@
             default:
                 break;
         }
-        if (sAccessibilityCache != null) {
-            sAccessibilityCache.onAccessibilityEvent(event);
+        AccessibilityCache cache = getCache(connectionId);
+        if (cache == null) {
+            if (DEBUG) {
+                Log.w(LOG_TAG, "Cache is null for connection id: " + connectionId);
+            }
+            return;
         }
+        cache.onAccessibilityEvent(event);
     }
 
     /**
@@ -1216,8 +1242,15 @@
                 }
             }
             info.setSealed(true);
-            if (!bypassCache && sAccessibilityCache != null) {
-                sAccessibilityCache.add(info);
+            if (!bypassCache) {
+                AccessibilityCache cache = getCache(connectionId);
+                if (cache == null) {
+                    if (DEBUG) {
+                        Log.w(LOG_TAG, "Cache is null for connection id: " + connectionId);
+                    }
+                    return;
+                }
+                cache.add(info);
             }
         }
     }
diff --git a/core/tests/coretests/src/android/view/accessibility/AccessibilityCacheTest.java b/core/tests/coretests/src/android/view/accessibility/AccessibilityCacheTest.java
index 33c6a4b..e689b5d3 100644
--- a/core/tests/coretests/src/android/view/accessibility/AccessibilityCacheTest.java
+++ b/core/tests/coretests/src/android/view/accessibility/AccessibilityCacheTest.java
@@ -86,7 +86,6 @@
     public void tearDown() {
         // Make sure we're recycling all of our window and node infos.
         mAccessibilityCache.clear();
-        AccessibilityInteractionClient.getInstance().clearCache();
     }
 
     @Test
diff --git a/core/tests/coretests/src/android/view/accessibility/AccessibilityInteractionClientTest.java b/core/tests/coretests/src/android/view/accessibility/AccessibilityInteractionClientTest.java
index 7e1e7f4..3e061d2 100644
--- a/core/tests/coretests/src/android/view/accessibility/AccessibilityInteractionClientTest.java
+++ b/core/tests/coretests/src/android/view/accessibility/AccessibilityInteractionClientTest.java
@@ -17,6 +17,9 @@
 package android.view.accessibility;
 
 import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertNotSame;
+import static org.junit.Assert.assertNull;
 import static org.mockito.Mockito.verifyZeroInteractions;
 import static org.mockito.MockitoAnnotations.initMocks;
 
@@ -40,6 +43,7 @@
 @RunWith(AndroidJUnit4.class)
 public class AccessibilityInteractionClientTest {
     private static final int MOCK_CONNECTION_ID = 0xabcd;
+    private static final int MOCK_CONNECTION_OTHER_ID = 0xabce;
 
     private MockConnection mMockConnection = new MockConnection();
     @Mock private AccessibilityCache mMockCache;
@@ -47,8 +51,11 @@
     @Before
     public void setUp() {
         initMocks(this);
-        AccessibilityInteractionClient.setCache(mMockCache);
-        AccessibilityInteractionClient.addConnection(MOCK_CONNECTION_ID, mMockConnection);
+        // Connections and caches are static, process-wide state: drop anything an earlier test
+        // registered so that tests do not depend on their execution order.
+        AccessibilityInteractionClient.removeConnection(MOCK_CONNECTION_OTHER_ID);
+        AccessibilityInteractionClient.addConnection(
+                MOCK_CONNECTION_ID, mMockConnection, /*initializeCache=*/true);
     }
 
     /**
@@ -58,6 +65,7 @@
      */
     @Test
     public void findA11yNodeInfoByA11yId_whenBypassingCache_doesntTouchCache() {
+        AccessibilityInteractionClient.setCache(MOCK_CONNECTION_ID, mMockCache);
         final int windowId = 0x1234;
         final long accessibilityNodeId = 0x4321L;
         AccessibilityNodeInfo nodeFromConnection = AccessibilityNodeInfo.obtain();
@@ -71,6 +79,41 @@
         verifyZeroInteractions(mMockCache);
     }
 
+    @Test
+    public void getCache_differentConnections_returnsDifferentCaches() {
+        MockConnection otherMockConnection = new MockConnection();
+        AccessibilityInteractionClient.addConnection(
+                MOCK_CONNECTION_OTHER_ID, otherMockConnection, /*initializeCache=*/true);
+
+        AccessibilityCache firstCache = AccessibilityInteractionClient.getCache(MOCK_CONNECTION_ID);
+        AccessibilityCache secondCache = AccessibilityInteractionClient.getCache(
+                MOCK_CONNECTION_OTHER_ID);
+        assertNotNull(firstCache);
+        assertNotNull(secondCache);
+        assertNotSame(firstCache, secondCache);
+    }
+
+    @Test
+    public void getCache_addConnectionWithoutCache_returnsNullCache() {
+        MockConnection otherMockConnection = new MockConnection();
+        AccessibilityInteractionClient.addConnection(
+                MOCK_CONNECTION_OTHER_ID, otherMockConnection, /*initializeCache=*/false);
+
+        AccessibilityCache cache = AccessibilityInteractionClient.getCache(
+                MOCK_CONNECTION_OTHER_ID);
+        assertNull(cache);
+    }
+
+    @Test
+    public void getCache_removeConnection_returnsNull() {
+        AccessibilityCache cache = AccessibilityInteractionClient.getCache(MOCK_CONNECTION_ID);
+        assertNotNull(cache);
+
+        AccessibilityInteractionClient.removeConnection(MOCK_CONNECTION_ID);
+        cache = AccessibilityInteractionClient.getCache(MOCK_CONNECTION_ID);
+        assertNull(cache);
+    }
+
     private static class MockConnection extends AccessibilityServiceConnectionImpl {
         AccessibilityNodeInfo mInfoToReturn;
 
diff --git a/services/accessibility/java/com/android/server/accessibility/AccessibilityManagerService.java b/services/accessibility/java/com/android/server/accessibility/AccessibilityManagerService.java
index 572cfdc..f3a5d35 100644
--- a/services/accessibility/java/com/android/server/accessibility/AccessibilityManagerService.java
+++ b/services/accessibility/java/com/android/server/accessibility/AccessibilityManagerService.java
@@ -3528,9 +3528,10 @@
 
             mConnectionId = service.mId;
 
-            mClient = AccessibilityInteractionClient.getInstance(/* initializeCache= */false,
-                    mContext);
-            mClient.addConnection(mConnectionId, service);
+            mClient = AccessibilityInteractionClient.getInstance(mContext);
+            // The system server registers this connection without a node cache.
+            AccessibilityInteractionClient.addConnection(
+                    mConnectionId, service, /*initializeCache=*/false);
 
             //TODO: (multi-display) We need to support multiple displays.
             DisplayManager displayManager = (DisplayManager)