Update TaskView corner radius dynamically

Previously we set the corner radius just once when TaskView is
constructed, but this doesn't work when reusing the TaskView on a
different display that has a different corner radius.

Flag: none
Test: FullScreenDrawParamsTest
Test: kill launcher, open Overview on one display, then switch to
another display with different corner radii and ensure task radii
have updated while quick switching
Fixes: 293224095

Change-Id: I5f0697a4697400ec0e003c116774d74a945ee59e
diff --git a/quickstep/src/com/android/quickstep/views/DesktopTaskView.java b/quickstep/src/com/android/quickstep/views/DesktopTaskView.java
index 5f3fd0c..dfd4390 100644
--- a/quickstep/src/com/android/quickstep/views/DesktopTaskView.java
+++ b/quickstep/src/com/android/quickstep/views/DesktopTaskView.java
@@ -109,9 +109,17 @@
     public DesktopTaskView(Context context, AttributeSet attrs, int defStyleAttr) {
         super(context, attrs, defStyleAttr);
 
-        mSnapshotDrawParams = new FullscreenDrawParams(
-                QuickStepContract.getWindowCornerRadius(context),
-                QuickStepContract.getWindowCornerRadius(context));
+        mSnapshotDrawParams = new FullscreenDrawParams(context) {
+            @Override
+            public float computeTaskCornerRadius(Context context) {
+                return QuickStepContract.getWindowCornerRadius(context);
+            }
+
+            @Override
+            public float computeWindowCornerRadius(Context context) {
+                return QuickStepContract.getWindowCornerRadius(context);
+            }
+        };
     }
 
     @Override
diff --git a/quickstep/src/com/android/quickstep/views/TaskView.java b/quickstep/src/com/android/quickstep/views/TaskView.java
index 854c3c7..a2976a8 100644
--- a/quickstep/src/com/android/quickstep/views/TaskView.java
+++ b/quickstep/src/com/android/quickstep/views/TaskView.java
@@ -71,6 +71,7 @@
 import androidx.annotation.IntDef;
 import androidx.annotation.NonNull;
 import androidx.annotation.Nullable;
+import androidx.annotation.VisibleForTesting;
 
 import com.android.app.animation.Interpolators;
 import com.android.launcher3.DeviceProfile;
@@ -133,15 +134,17 @@
 
     public static final int FLAG_UPDATE_ICON = 1;
     public static final int FLAG_UPDATE_THUMBNAIL = FLAG_UPDATE_ICON << 1;
+    public static final int FLAG_UPDATE_CORNER_RADIUS = FLAG_UPDATE_THUMBNAIL << 1;
 
-    public static final int FLAG_UPDATE_ALL = FLAG_UPDATE_ICON | FLAG_UPDATE_THUMBNAIL;
+    public static final int FLAG_UPDATE_ALL = FLAG_UPDATE_ICON | FLAG_UPDATE_THUMBNAIL
+            | FLAG_UPDATE_CORNER_RADIUS;
 
     /**
      * Used in conjunction with {@link #onTaskListVisibilityChanged(boolean, int)}, providing more
      * granularity on which components of this task require an update
      */
     @Retention(SOURCE)
-    @IntDef({FLAG_UPDATE_ALL, FLAG_UPDATE_ICON, FLAG_UPDATE_THUMBNAIL})
+    @IntDef({FLAG_UPDATE_ALL, FLAG_UPDATE_ICON, FLAG_UPDATE_THUMBNAIL, FLAG_UPDATE_CORNER_RADIUS})
     public @interface TaskDataChanges {}
 
     /**
@@ -1079,6 +1082,9 @@
                             mDigitalWellBeingToast.initialize(task);
                         });
             }
+            if (needsUpdate(changes, FLAG_UPDATE_CORNER_RADIUS)) {
+                mCurrentFullscreenParams.updateCornerRadius(getContext());
+            }
         } else {
             if (needsUpdate(changes, FLAG_UPDATE_THUMBNAIL)) {
                 mSnapshotView.setThumbnail(null, null);
@@ -1859,19 +1865,29 @@
      */
     public static class FullscreenDrawParams {
 
-        private final float mCornerRadius;
-        private final float mWindowCornerRadius;
+        private float mCornerRadius;
+        private float mWindowCornerRadius;
 
         public float mCurrentDrawnCornerRadius;
 
         public FullscreenDrawParams(Context context) {
-            this(TaskCornerRadius.get(context), QuickStepContract.getWindowCornerRadius(context));
+            updateCornerRadius(context);
         }
 
-        FullscreenDrawParams(float cornerRadius, float windowCornerRadius) {
-            mCornerRadius = cornerRadius;
-            mWindowCornerRadius = windowCornerRadius;
-            mCurrentDrawnCornerRadius = mCornerRadius;
+        /** Recomputes the start and end corner radius for the given Context. */
+        public void updateCornerRadius(Context context) {
+            mCornerRadius = computeTaskCornerRadius(context);
+            mWindowCornerRadius = computeWindowCornerRadius(context);
+        }
+
+        @VisibleForTesting
+        public float computeTaskCornerRadius(Context context) {
+            return TaskCornerRadius.get(context);
+        }
+
+        @VisibleForTesting
+        public float computeWindowCornerRadius(Context context) {
+            return QuickStepContract.getWindowCornerRadius(context);
         }
 
         /**
diff --git a/quickstep/tests/src/com/android/quickstep/FullscreenDrawParamsTest.kt b/quickstep/tests/src/com/android/quickstep/FullscreenDrawParamsTest.kt
index a9dc043..bb1afdf 100644
--- a/quickstep/tests/src/com/android/quickstep/FullscreenDrawParamsTest.kt
+++ b/quickstep/tests/src/com/android/quickstep/FullscreenDrawParamsTest.kt
@@ -15,6 +15,7 @@
  */
 package com.android.quickstep
 
+import android.content.Context
 import android.graphics.Rect
 import androidx.test.ext.junit.runners.AndroidJUnit4
 import androidx.test.filters.SmallTest
@@ -29,7 +30,9 @@
 import org.junit.Before
 import org.junit.Test
 import org.junit.runner.RunWith
+import org.mockito.Mockito.doReturn
 import org.mockito.Mockito.mock
+import org.mockito.Mockito.spy
 
 /** Test for FullscreenDrawParams class. */
 @SmallTest
@@ -186,4 +189,76 @@
         val expectedRadius = QuickStepContract.getWindowCornerRadius(context)
         assertThat(params.mCurrentDrawnCornerRadius).isEqualTo(expectedRadius)
     }
+
+    @Test
+    fun setStartProgress_correctCornerRadiusForMultiDisplay() {
+        val display1Context = context
+        val display2Context = mock(Context::class.java)
+        val spyParams = spy(params)
+
+        val display1TaskRadius = TaskCornerRadius.get(display1Context)
+        val display1WindowRadius = QuickStepContract.getWindowCornerRadius(display1Context)
+        val display2TaskRadius = display1TaskRadius * 2 + 1 // Arbitrarily different.
+        val display2WindowRadius = display1WindowRadius * 2 + 1 // Arbitrarily different.
+        doReturn(display2TaskRadius).`when`(spyParams).computeTaskCornerRadius(display2Context)
+        doReturn(display2WindowRadius).`when`(spyParams).computeWindowCornerRadius(display2Context)
+
+        spyParams.updateCornerRadius(display1Context)
+        spyParams.setProgress(
+            /* fullscreenProgress= */ 0f,
+            /* parentScale= */ 1.0f,
+            /* taskViewScale= */ 1.0f,
+            /* unused previewWidth= */ -1,
+            /* unusedDp= */ null,
+            /* unused previewPositionHelper= */ null
+        )
+        assertThat(spyParams.mCurrentDrawnCornerRadius).isEqualTo(display1TaskRadius)
+
+        spyParams.updateCornerRadius(display2Context)
+        spyParams.setProgress(
+            /* fullscreenProgress= */ 0f,
+            /* parentScale= */ 1.0f,
+            /* taskViewScale= */ 1.0f,
+            /* unused previewWidth= */ -1,
+            /* unusedDp= */ null,
+            /* unused previewPositionHelper= */ null
+        )
+        assertThat(spyParams.mCurrentDrawnCornerRadius).isEqualTo(display2TaskRadius)
+    }
+
+    @Test
+    fun setFullProgress_correctCornerRadiusForMultiDisplay() {
+        val display1Context = context
+        val display2Context = mock(Context::class.java)
+        val spyParams = spy(params)
+
+        val display1TaskRadius = TaskCornerRadius.get(display1Context)
+        val display1WindowRadius = QuickStepContract.getWindowCornerRadius(display1Context)
+        val display2TaskRadius = display1TaskRadius * 2 + 1 // Arbitrarily different.
+        val display2WindowRadius = display1WindowRadius * 2 + 1 // Arbitrarily different.
+        doReturn(display2TaskRadius).`when`(spyParams).computeTaskCornerRadius(display2Context)
+        doReturn(display2WindowRadius).`when`(spyParams).computeWindowCornerRadius(display2Context)
+
+        spyParams.updateCornerRadius(display1Context)
+        spyParams.setProgress(
+            /* fullscreenProgress= */ 1.0f,
+            /* parentScale= */ 1.0f,
+            /* taskViewScale= */ 1.0f,
+            /* unused previewWidth= */ -1,
+            /* unusedDp= */ null,
+            /* unused previewPositionHelper= */ null
+        )
+        assertThat(spyParams.mCurrentDrawnCornerRadius).isEqualTo(display1WindowRadius)
+
+        spyParams.updateCornerRadius(display2Context)
+        spyParams.setProgress(
+            /* fullscreenProgress= */ 1.0f,
+            /* parentScale= */ 1.0f,
+            /* taskViewScale= */ 1.0f,
+            /* unused previewWidth= */ -1,
+            /* unusedDp= */ null,
+            /* unused previewPositionHelper= */ null
+        )
+        assertThat(spyParams.mCurrentDrawnCornerRadius).isEqualTo(display2WindowRadius)
+    }
 }