Enlarge pointer icon for magnification

When full screen magnification is enabled, pointer icons should be also
enlarged by the zoom factor.

This change adds an API in InputManagerInternal for accessibility to
update the scale factor when full screen magnification is enabled.

Currently there's only one caller, but it will be called from more
places in future changes.

Bug: 355734856
Test: Enable flag, and changing scale with slider or gesture will change pointer icons.
Test: PointerIconCacheTest
Test: atest com.android.server.accessibility.magnification
Flag: com.android.server.accessibility.magnification_enlarge_pointer

Change-Id: I985183b12d6a1c4a6aa4c53e4f462778aeb2e9ae
diff --git a/services/accessibility/java/com/android/server/accessibility/magnification/FullScreenMagnificationController.java b/services/accessibility/java/com/android/server/accessibility/magnification/FullScreenMagnificationController.java
index a77ba62..12c64c5 100644
--- a/services/accessibility/java/com/android/server/accessibility/magnification/FullScreenMagnificationController.java
+++ b/services/accessibility/java/com/android/server/accessibility/magnification/FullScreenMagnificationController.java
@@ -64,6 +64,7 @@
 import com.android.server.accessibility.AccessibilityManagerService;
 import com.android.server.accessibility.AccessibilityTraceManager;
 import com.android.server.accessibility.Flags;
+import com.android.server.input.InputManagerInternal;
 import com.android.server.wm.WindowManagerInternal;
 
 import java.util.ArrayList;
@@ -955,6 +956,7 @@
                         context,
                         traceManager,
                         LocalServices.getService(WindowManagerInternal.class),
+                        LocalServices.getService(InputManagerInternal.class),
                         new Handler(context.getMainLooper()),
                         context.getResources().getInteger(R.integer.config_longAnimTime)),
                 lock,
@@ -1640,6 +1642,8 @@
      */
     public void persistScale(int displayId) {
         final float scale = getScale(displayId);
+        notifyScaleForInput(displayId, scale);
+
         if (scale < MagnificationConstants.PERSISTED_SCALE_MIN_VALUE) {
             return;
         }
@@ -1691,6 +1695,20 @@
     }
 
     /**
+     * Notifies input manager that magnification scale changed non-transiently
+     * so that pointer cursor is scaled as well.
+     *
+     * @param displayId The logical display id.
+     * @param scale     The new scale factor.
+     */
+    public void notifyScaleForInput(int displayId, float scale) {
+        if (Flags.magnificationEnlargePointer()) {
+            mControllerCtx.getInputManager()
+                    .setAccessibilityPointerIconScaleFactor(displayId, scale);
+        }
+    }
+
+    /**
      * Resets all displays' magnification if last magnifying service is disabled.
      *
      * @param connectionId
@@ -2166,6 +2184,7 @@
         private final Context mContext;
         private final AccessibilityTraceManager mTrace;
         private final WindowManagerInternal mWindowManager;
+        private final InputManagerInternal mInputManager;
         private final Handler mHandler;
         private final Long mAnimationDuration;
 
@@ -2175,11 +2194,13 @@
         public ControllerContext(@NonNull Context context,
                 @NonNull AccessibilityTraceManager traceManager,
                 @NonNull WindowManagerInternal windowManager,
+                @NonNull InputManagerInternal inputManager,
                 @NonNull Handler handler,
                 long animationDuration) {
             mContext = context;
             mTrace = traceManager;
             mWindowManager = windowManager;
+            mInputManager = inputManager;
             mHandler = handler;
             mAnimationDuration = animationDuration;
         }
@@ -2209,6 +2230,14 @@
         }
 
         /**
+         * @return InputManagerInternal
+         */
+        @NonNull
+        public InputManagerInternal getInputManager() {
+            return mInputManager;
+        }
+
+        /**
          * @return Handler for main looper
          */
         @NonNull
diff --git a/services/core/java/com/android/server/input/InputManagerInternal.java b/services/core/java/com/android/server/input/InputManagerInternal.java
index 99f7f12..c888eef 100644
--- a/services/core/java/com/android/server/input/InputManagerInternal.java
+++ b/services/core/java/com/android/server/input/InputManagerInternal.java
@@ -262,4 +262,12 @@
      */
     public abstract void handleKeyGestureInKeyGestureController(int deviceId, int[] keycodes,
             int modifierState, @KeyGestureEvent.KeyGestureType int event);
+
+    /**
+     * Sets the magnification scale factor for pointer icons.
+     *
+     * @param displayId   the ID of the display where the new scale factor is applied.
+     * @param scaleFactor the new scale factor to be applied for pointer icons.
+     */
+    public abstract void setAccessibilityPointerIconScaleFactor(int displayId, float scaleFactor);
 }
diff --git a/services/core/java/com/android/server/input/InputManagerService.java b/services/core/java/com/android/server/input/InputManagerService.java
index 8acf583..98e5319 100644
--- a/services/core/java/com/android/server/input/InputManagerService.java
+++ b/services/core/java/com/android/server/input/InputManagerService.java
@@ -3506,6 +3506,11 @@
                 int modifierState, @KeyGestureEvent.KeyGestureType int gestureType) {
             mKeyGestureController.handleKeyGesture(deviceId, keycodes, modifierState, gestureType);
         }
+
+        @Override
+        public void setAccessibilityPointerIconScaleFactor(int displayId, float scaleFactor) {
+            InputManagerService.this.setAccessibilityPointerIconScaleFactor(displayId, scaleFactor);
+        }
     }
 
     @Override
@@ -3688,6 +3693,10 @@
         mPointerIconCache.setPointerScale(scale);
     }
 
+    void setAccessibilityPointerIconScaleFactor(int displayId, float scaleFactor) {
+        mPointerIconCache.setAccessibilityScaleFactor(displayId, scaleFactor);
+    }
+
     interface KeyboardBacklightControllerInterface {
         default void incrementKeyboardBacklight(int deviceId) {}
         default void decrementKeyboardBacklight(int deviceId) {}
diff --git a/services/core/java/com/android/server/input/PointerIconCache.java b/services/core/java/com/android/server/input/PointerIconCache.java
index 297cd68..e16031c 100644
--- a/services/core/java/com/android/server/input/PointerIconCache.java
+++ b/services/core/java/com/android/server/input/PointerIconCache.java
@@ -27,6 +27,7 @@
 import android.os.Handler;
 import android.util.Slog;
 import android.util.SparseArray;
+import android.util.SparseDoubleArray;
 import android.util.SparseIntArray;
 import android.view.ContextThemeWrapper;
 import android.view.Display;
@@ -34,6 +35,7 @@
 import android.view.PointerIcon;
 
 import com.android.internal.annotations.GuardedBy;
+import com.android.internal.annotations.VisibleForTesting;
 import com.android.server.UiThread;
 
 import java.util.Objects;
@@ -51,7 +53,7 @@
     private final NativeInputManagerService mNative;
 
     // We use the UI thread for loading pointer icons.
-    private final Handler mUiThreadHandler = UiThread.getHandler();
+    private final Handler mUiThreadHandler;
 
     @GuardedBy("mLoadedPointerIconsByDisplayAndType")
     private final SparseArray<SparseArray<PointerIcon>> mLoadedPointerIconsByDisplayAndType =
@@ -70,6 +72,9 @@
             POINTER_ICON_VECTOR_STYLE_STROKE_WHITE;
     @GuardedBy("mLoadedPointerIconsByDisplayAndType")
     private float mPointerIconScale = DEFAULT_POINTER_SCALE;
+    // Note that android doesn't have SparseFloatArray, so this falls back to use double instead.
+    @GuardedBy("mLoadedPointerIconsByDisplayAndType")
+    private final SparseDoubleArray mAccessibilityScaleFactorPerDisplay = new SparseDoubleArray();
 
     private final DisplayManager.DisplayListener mDisplayListener =
             new DisplayManager.DisplayListener() {
@@ -86,6 +91,7 @@
                         mLoadedPointerIconsByDisplayAndType.remove(displayId);
                         mDisplayContexts.remove(displayId);
                         mDisplayDensities.delete(displayId);
+                        mAccessibilityScaleFactorPerDisplay.delete(displayId);
                     }
                 }
 
@@ -96,8 +102,15 @@
             };
 
     /* package */ PointerIconCache(Context context, NativeInputManagerService nativeService) {
+        this(context, nativeService, UiThread.getHandler());
+    }
+
+    @VisibleForTesting
+    /* package */ PointerIconCache(Context context, NativeInputManagerService nativeService,
+            Handler handler) {
         mContext = context;
         mNative = nativeService;
+        mUiThreadHandler = handler;
     }
 
     public void systemRunning() {
@@ -134,6 +147,11 @@
         mUiThreadHandler.post(() -> handleSetPointerScale(scale));
     }
 
+    /** Set the scale for accessibility (magnification) for vector pointer icons. */
+    public void setAccessibilityScaleFactor(int displayId, float scaleFactor) {
+        mUiThreadHandler.post(() -> handleAccessibilityScaleFactor(displayId, scaleFactor));
+    }
+
     /**
      * Get a loaded system pointer icon. This will fetch the icon from the cache, or load it if
      * it isn't already cached.
@@ -155,8 +173,10 @@
                         /* force= */ true);
                 theme.applyStyle(PointerIcon.vectorStrokeStyleToResource(mPointerIconStrokeStyle),
                         /* force= */ true);
+                final float scale = mPointerIconScale
+                        * (float) mAccessibilityScaleFactorPerDisplay.get(displayId, 1f);
                 icon = PointerIcon.getLoadedSystemIcon(new ContextThemeWrapper(context, theme),
-                        type, mUseLargePointerIcons, mPointerIconScale);
+                        type, mUseLargePointerIcons, scale);
                 iconsByType.put(type, icon);
             }
             return Objects.requireNonNull(icon);
@@ -261,6 +281,19 @@
         mNative.reloadPointerIcons();
     }
 
+    @android.annotation.UiThread
+    private void handleAccessibilityScaleFactor(int displayId, float scale) {
+        synchronized (mLoadedPointerIconsByDisplayAndType) {
+            if (mAccessibilityScaleFactorPerDisplay.get(displayId, 1f) == scale) {
+                return;
+            }
+            mAccessibilityScaleFactorPerDisplay.put(displayId, scale);
+            // Clear cached icons on the display.
+            mLoadedPointerIconsByDisplayAndType.remove(displayId);
+        }
+        mNative.reloadPointerIcons();
+    }
+
     // Updates the cached display density for the given displayId, and returns true if
     // the cached density changed.
     @GuardedBy("mLoadedPointerIconsByDisplayAndType")
diff --git a/services/tests/servicestests/src/com/android/server/accessibility/magnification/FullScreenMagnificationControllerTest.java b/services/tests/servicestests/src/com/android/server/accessibility/magnification/FullScreenMagnificationControllerTest.java
index c4b4afd..76553ba 100644
--- a/services/tests/servicestests/src/com/android/server/accessibility/magnification/FullScreenMagnificationControllerTest.java
+++ b/services/tests/servicestests/src/com/android/server/accessibility/magnification/FullScreenMagnificationControllerTest.java
@@ -18,6 +18,7 @@
 
 import static android.accessibilityservice.MagnificationConfig.MAGNIFICATION_MODE_FULLSCREEN;
 
+import static com.android.server.accessibility.Flags.FLAG_MAGNIFICATION_ENLARGE_POINTER;
 import static com.android.server.accessibility.magnification.FullScreenMagnificationController.MagnificationInfoChangedCallback;
 import static com.android.server.accessibility.magnification.MockMagnificationConnection.TEST_DISPLAY;
 import static com.android.window.flags.Flags.FLAG_ALWAYS_DRAW_MAGNIFICATION_FULLSCREEN_BORDER;
@@ -76,6 +77,7 @@
 import com.android.server.accessibility.AccessibilityTraceManager;
 import com.android.server.accessibility.Flags;
 import com.android.server.accessibility.test.MessageCapturingHandler;
+import com.android.server.input.InputManagerInternal;
 import com.android.server.wm.WindowManagerInternal;
 import com.android.server.wm.WindowManagerInternal.MagnificationCallbacks;
 
@@ -126,6 +128,7 @@
     final Resources mMockResources = mock(Resources.class);
     final AccessibilityTraceManager mMockTraceManager = mock(AccessibilityTraceManager.class);
     final WindowManagerInternal mMockWindowManager = mock(WindowManagerInternal.class);
+    final InputManagerInternal mMockInputManager = mock(InputManagerInternal.class);
     private final MagnificationAnimationCallback mAnimationCallback = mock(
             MagnificationAnimationCallback.class);
     private final MagnificationInfoChangedCallback mRequestObserver = mock(
@@ -163,6 +166,7 @@
         when(mMockControllerCtx.getContext()).thenReturn(mMockContext);
         when(mMockControllerCtx.getTraceManager()).thenReturn(mMockTraceManager);
         when(mMockControllerCtx.getWindowManager()).thenReturn(mMockWindowManager);
+        when(mMockControllerCtx.getInputManager()).thenReturn(mMockInputManager);
         when(mMockControllerCtx.getHandler()).thenReturn(mMessageCapturingHandler);
         when(mMockControllerCtx.getAnimationDuration()).thenReturn(1000L);
         mResolver = new MockContentResolver();
@@ -1479,6 +1483,23 @@
     }
 
     @Test
+    @RequiresFlagsEnabled(FLAG_MAGNIFICATION_ENLARGE_POINTER)
+    public void persistScale_setValue_notifyInput() {
+        register(TEST_DISPLAY);
+
+        PointF pivotPoint = INITIAL_BOUNDS_LOWER_RIGHT_2X_CENTER;
+        mFullScreenMagnificationController.setScale(TEST_DISPLAY, 4.0f, pivotPoint.x, pivotPoint.y,
+                true, SERVICE_ID_1);
+        mFullScreenMagnificationController.persistScale(TEST_DISPLAY);
+
+        // persistScale may post a task to a background thread. Let's wait for it completes.
+        waitForBackgroundThread();
+        Assert.assertEquals(mFullScreenMagnificationController.getPersistedScale(TEST_DISPLAY),
+                4.0f);
+        verify(mMockInputManager).setAccessibilityPointerIconScaleFactor(TEST_DISPLAY, 4.0f);
+    }
+
+    @Test
     public void testOnContextChanged_alwaysOnFeatureDisabled_resetMagnification() {
         setScaleToMagnifying();
 
diff --git a/services/tests/servicestests/src/com/android/server/accessibility/magnification/FullScreenMagnificationGestureHandlerTest.java b/services/tests/servicestests/src/com/android/server/accessibility/magnification/FullScreenMagnificationGestureHandlerTest.java
index b745e6a..00b7de8 100644
--- a/services/tests/servicestests/src/com/android/server/accessibility/magnification/FullScreenMagnificationGestureHandlerTest.java
+++ b/services/tests/servicestests/src/com/android/server/accessibility/magnification/FullScreenMagnificationGestureHandlerTest.java
@@ -90,6 +90,7 @@
 import com.android.server.accessibility.EventStreamTransformation;
 import com.android.server.accessibility.Flags;
 import com.android.server.accessibility.magnification.FullScreenMagnificationController.MagnificationInfoChangedCallback;
+import com.android.server.input.InputManagerInternal;
 import com.android.server.testutils.OffsettableClock;
 import com.android.server.testutils.TestHandler;
 import com.android.server.wm.WindowManagerInternal;
@@ -227,9 +228,11 @@
         final FullScreenMagnificationController.ControllerContext mockController =
                 mock(FullScreenMagnificationController.ControllerContext.class);
         final WindowManagerInternal mockWindowManager = mock(WindowManagerInternal.class);
+        final InputManagerInternal mockInputManager = mock(InputManagerInternal.class);
         when(mockController.getContext()).thenReturn(mContext);
         when(mockController.getTraceManager()).thenReturn(mMockTraceManager);
         when(mockController.getWindowManager()).thenReturn(mockWindowManager);
+        when(mockController.getInputManager()).thenReturn(mockInputManager);
         when(mockController.getHandler()).thenReturn(new Handler(mContext.getMainLooper()));
         when(mockController.newValueAnimator()).thenReturn(new ValueAnimator());
         when(mockController.getAnimationDuration()).thenReturn(1000L);
diff --git a/services/tests/servicestests/src/com/android/server/accessibility/magnification/MagnificationControllerTest.java b/services/tests/servicestests/src/com/android/server/accessibility/magnification/MagnificationControllerTest.java
index 2528177..d70e1fe 100644
--- a/services/tests/servicestests/src/com/android/server/accessibility/magnification/MagnificationControllerTest.java
+++ b/services/tests/servicestests/src/com/android/server/accessibility/magnification/MagnificationControllerTest.java
@@ -76,6 +76,7 @@
 import com.android.server.accessibility.AccessibilityManagerService;
 import com.android.server.accessibility.AccessibilityTraceManager;
 import com.android.server.accessibility.test.MessageCapturingHandler;
+import com.android.server.input.InputManagerInternal;
 import com.android.server.wm.WindowManagerInternal;
 import com.android.window.flags.Flags;
 
@@ -154,6 +155,8 @@
     private WindowManagerInternal mWindowManagerInternal;
     @Mock
     private WindowManagerInternal.AccessibilityControllerInternal mA11yController;
+    @Mock
+    private InputManagerInternal mInputManagerInternal;
 
     @Mock
     private DisplayManagerInternal mDisplayManagerInternal;
@@ -200,6 +203,7 @@
         when(mControllerCtx.getContext()).thenReturn(mContext);
         when(mControllerCtx.getTraceManager()).thenReturn(mTraceManager);
         when(mControllerCtx.getWindowManager()).thenReturn(mWindowManagerInternal);
+        when(mControllerCtx.getInputManager()).thenReturn(mInputManagerInternal);
         when(mControllerCtx.getHandler()).thenReturn(mMessageCapturingHandler);
         when(mControllerCtx.getAnimationDuration()).thenReturn(1000L);
         when(mControllerCtx.newValueAnimator()).thenReturn(mValueAnimator);
diff --git a/tests/Input/Android.bp b/tests/Input/Android.bp
index 6742cbe..168141b 100644
--- a/tests/Input/Android.bp
+++ b/tests/Input/Android.bp
@@ -41,6 +41,7 @@
         "hamcrest-library",
         "junit-params",
         "kotlin-test",
+        "mockito-kotlin-nodeps",
         "mockito-target-extended-minus-junit4",
         "platform-test-annotations",
         "platform-screenshot-diff-core",
diff --git a/tests/Input/src/com/android/server/input/PointerIconCacheTest.kt b/tests/Input/src/com/android/server/input/PointerIconCacheTest.kt
new file mode 100644
index 0000000..47e7ac7
--- /dev/null
+++ b/tests/Input/src/com/android/server/input/PointerIconCacheTest.kt
@@ -0,0 +1,135 @@
+/*
+ * Copyright 2024 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.server.input
+
+import android.content.Context
+import android.content.ContextWrapper
+import android.os.Handler
+import android.os.test.TestLooper
+import android.platform.test.annotations.Presubmit
+import android.view.Display
+import android.view.PointerIcon
+import androidx.test.platform.app.InstrumentationRegistry
+import junit.framework.Assert.assertEquals
+import org.junit.Before
+import org.junit.Rule
+import org.junit.Test
+import org.mockito.Mock
+import org.mockito.junit.MockitoJUnit
+import org.mockito.kotlin.times
+import org.mockito.kotlin.verify
+import org.mockito.kotlin.whenever
+
+/**
+ * Tests for {@link PointerIconCache}.
+ */
+@Presubmit
+class PointerIconCacheTest {
+
+    @get:Rule
+    val rule = MockitoJUnit.rule()!!
+
+    @Mock
+    private lateinit var native: NativeInputManagerService
+    @Mock
+    private lateinit var defaultDisplay: Display
+
+    private lateinit var context: Context
+    private lateinit var testLooper: TestLooper
+    private lateinit var cache: PointerIconCache
+
+    @Before
+    fun setup() {
+        whenever(defaultDisplay.displayId).thenReturn(Display.DEFAULT_DISPLAY)
+
+        context = object : ContextWrapper(InstrumentationRegistry.getInstrumentation().context) {
+            override fun getDisplay() = defaultDisplay
+        }
+
+        testLooper = TestLooper()
+        cache = PointerIconCache(context, native, Handler(testLooper.looper))
+    }
+
+    @Test
+    fun testSetPointerScale() {
+        val defaultBitmap = getDefaultIcon().bitmap
+        cache.setPointerScale(2f)
+
+        testLooper.dispatchAll()
+        verify(native).reloadPointerIcons()
+
+        val bitmap =
+            cache.getLoadedPointerIcon(Display.DEFAULT_DISPLAY, PointerIcon.TYPE_ARROW).bitmap
+
+        assertEquals(defaultBitmap.height * 2, bitmap.height)
+        assertEquals(defaultBitmap.width * 2, bitmap.width)
+    }
+
+    @Test
+    fun testSetAccessibilityScaleFactor() {
+        val defaultBitmap = getDefaultIcon().bitmap
+        cache.setAccessibilityScaleFactor(Display.DEFAULT_DISPLAY, 4f)
+
+        testLooper.dispatchAll()
+        verify(native).reloadPointerIcons()
+
+        val bitmap =
+            cache.getLoadedPointerIcon(Display.DEFAULT_DISPLAY, PointerIcon.TYPE_ARROW).bitmap
+
+        assertEquals(defaultBitmap.height * 4, bitmap.height)
+        assertEquals(defaultBitmap.width * 4, bitmap.width)
+    }
+
+    @Test
+    fun testSetAccessibilityScaleFactorOnSecondaryDisplay() {
+        val defaultBitmap = getDefaultIcon().bitmap
+        val secondaryDisplayId = Display.DEFAULT_DISPLAY + 1
+        cache.setAccessibilityScaleFactor(secondaryDisplayId, 4f)
+
+        testLooper.dispatchAll()
+        verify(native).reloadPointerIcons()
+
+        val bitmap =
+            cache.getLoadedPointerIcon(Display.DEFAULT_DISPLAY, PointerIcon.TYPE_ARROW).bitmap
+        assertEquals(defaultBitmap.height, bitmap.height)
+        assertEquals(defaultBitmap.width, bitmap.width)
+
+        val bitmapSecondary =
+            cache.getLoadedPointerIcon(secondaryDisplayId, PointerIcon.TYPE_ARROW).bitmap
+        assertEquals(defaultBitmap.height * 4, bitmapSecondary.height)
+        assertEquals(defaultBitmap.width * 4, bitmapSecondary.width)
+    }
+
+    @Test
+    fun testSetPointerScaleAndAccessibilityScaleFactor() {
+        val defaultBitmap = getDefaultIcon().bitmap
+        cache.setPointerScale(2f)
+        cache.setAccessibilityScaleFactor(Display.DEFAULT_DISPLAY, 3f)
+
+        testLooper.dispatchAll()
+        verify(native, times(2)).reloadPointerIcons()
+
+        val bitmap =
+            cache.getLoadedPointerIcon(Display.DEFAULT_DISPLAY, PointerIcon.TYPE_ARROW).bitmap
+
+        assertEquals(defaultBitmap.height * 6, bitmap.height)
+        assertEquals(defaultBitmap.width * 6, bitmap.width)
+    }
+
+    private fun getDefaultIcon() =
+        PointerIcon.getLoadedSystemIcon(context, PointerIcon.TYPE_ARROW, false, 1f)
+}