Refactor and fix flaky tests in WindowOnBackInvokedDispatcherTest

Test: atest WindowOnBackInvokedDispatcherTest  --iterations 90
Bug: 282209142
Change-Id: I725bdd67d74e902672c79e1911e58d65f3d825ac
diff --git a/core/java/android/window/WindowOnBackInvokedDispatcher.java b/core/java/android/window/WindowOnBackInvokedDispatcher.java
index 22b2ec0..421d998 100644
--- a/core/java/android/window/WindowOnBackInvokedDispatcher.java
+++ b/core/java/android/window/WindowOnBackInvokedDispatcher.java
@@ -31,6 +31,8 @@
 import android.view.IWindow;
 import android.view.IWindowSession;
 
+import androidx.annotation.VisibleForTesting;
+
 import java.io.PrintWriter;
 import java.lang.ref.WeakReference;
 import java.util.ArrayList;
@@ -66,7 +68,9 @@
     /** Convenience hashmap to quickly decide if a callback has been added. */
     private final HashMap<OnBackInvokedCallback, Integer> mAllCallbacks = new HashMap<>();
     /** Holds all callbacks by priorities. */
-    private final TreeMap<Integer, ArrayList<OnBackInvokedCallback>>
+
+    @VisibleForTesting
+    public final TreeMap<Integer, ArrayList<OnBackInvokedCallback>>
             mOnBackInvokedCallbacks = new TreeMap<>();
     private Checker mChecker;
 
diff --git a/core/tests/coretests/src/android/window/WindowOnBackInvokedDispatcherTest.java b/core/tests/coretests/src/android/window/WindowOnBackInvokedDispatcherTest.java
index 8e772a2..2ef2d3a 100644
--- a/core/tests/coretests/src/android/window/WindowOnBackInvokedDispatcherTest.java
+++ b/core/tests/coretests/src/android/window/WindowOnBackInvokedDispatcherTest.java
@@ -16,12 +16,15 @@
 
 package android.window;
 
+import static android.window.OnBackInvokedDispatcher.PRIORITY_DEFAULT;
+import static android.window.OnBackInvokedDispatcher.PRIORITY_OVERLAY;
+
 import static org.junit.Assert.assertEquals;
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.ArgumentMatchers.isNull;
+import static org.mockito.Mockito.atLeast;
+import static org.mockito.Mockito.atMost;
 import static org.mockito.Mockito.doReturn;
-import static org.mockito.Mockito.reset;
-import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.verifyNoMoreInteractions;
 import static org.mockito.Mockito.verifyZeroInteractions;
@@ -45,6 +48,9 @@
 import org.mockito.Mockito;
 import org.mockito.MockitoAnnotations;
 
+import java.util.ArrayList;
+import java.util.List;
+
 /**
  * Tests for {@link WindowOnBackInvokedDispatcherTest}
  *
@@ -69,6 +75,8 @@
     @Mock
     private ApplicationInfo mApplicationInfo;
 
+    private int mCallbackInfoCalls = 0;
+
     private final BackMotionEvent mBackEvent = new BackMotionEvent(
             /* touchX = */ 0,
             /* touchY = */ 0,
@@ -93,105 +101,243 @@
         InstrumentationRegistry.getInstrumentation().waitForIdleSync();
     }
 
+    private List<OnBackInvokedCallbackInfo> captureCallbackInfo() throws RemoteException {
+        ArgumentCaptor<OnBackInvokedCallbackInfo> captor = ArgumentCaptor
+                .forClass(OnBackInvokedCallbackInfo.class);
+        // atLeast(0) -> get all setOnBackInvokedCallbackInfo() invocations
+        verify(mWindowSession, atLeast(0))
+                .setOnBackInvokedCallbackInfo(Mockito.eq(mWindow), captor.capture());
+        verifyNoMoreInteractions(mWindowSession);
+        return captor.getAllValues();
+    }
+
+    private OnBackInvokedCallbackInfo assertSetCallbackInfo() throws RemoteException {
+        List<OnBackInvokedCallbackInfo> callbackInfos = captureCallbackInfo();
+        int actual = callbackInfos.size();
+        assertEquals("setOnBackInvokedCallbackInfo", ++mCallbackInfoCalls, actual);
+        return callbackInfos.get(mCallbackInfoCalls - 1);
+    }
+
+    private void assertNoSetCallbackInfo() throws RemoteException {
+        List<OnBackInvokedCallbackInfo> callbackInfos = captureCallbackInfo();
+        int actual = callbackInfos.size();
+        assertEquals("No setOnBackInvokedCallbackInfo", mCallbackInfoCalls, actual);
+    }
+
+    private void assertCallbacksSize(int expectedDefault, int expectedOverlay) {
+        ArrayList<OnBackInvokedCallback> callbacksDefault = mDispatcher
+                .mOnBackInvokedCallbacks.get(PRIORITY_DEFAULT);
+        int actualSizeDefault = callbacksDefault != null ? callbacksDefault.size() : 0;
+        assertEquals("mOnBackInvokedCallbacks DEFAULT size", expectedDefault, actualSizeDefault);
+
+        ArrayList<OnBackInvokedCallback> callbacksOverlay = mDispatcher
+                .mOnBackInvokedCallbacks.get(PRIORITY_OVERLAY);
+        int actualSizeOverlay = callbacksOverlay != null ? callbacksOverlay.size() : 0;
+        assertEquals("mOnBackInvokedCallbacks OVERLAY size", expectedOverlay, actualSizeOverlay);
+    }
+
+    private void assertTopCallback(OnBackInvokedCallback expectedCallback) {
+        assertEquals("topCallback", expectedCallback, mDispatcher.getTopCallback());
+    }
+
+    @Test
+    public void registerCallback_samePriority_sameCallback() throws RemoteException {
+        mDispatcher.registerOnBackInvokedCallback(PRIORITY_DEFAULT, mCallback1);
+        assertCallbacksSize(/* default */ 1, /* overlay */ 0);
+        assertSetCallbackInfo();
+        assertTopCallback(mCallback1);
+
+        // The callback is removed and added again
+        mDispatcher.registerOnBackInvokedCallback(PRIORITY_DEFAULT, mCallback1);
+        assertCallbacksSize(/* default */ 1, /* overlay */ 0);
+        assertSetCallbackInfo();
+        assertTopCallback(mCallback1);
+
+        waitForIdle();
+        verifyNoMoreInteractions(mWindowSession);
+        verifyNoMoreInteractions(mCallback1);
+    }
+
+    @Test
+    public void registerCallback_samePriority_differentCallback() throws RemoteException {
+        mDispatcher.registerOnBackInvokedCallback(PRIORITY_DEFAULT, mCallback1);
+        assertCallbacksSize(/* default */ 1, /* overlay */ 0);
+        assertSetCallbackInfo();
+        assertTopCallback(mCallback1);
+
+        // The new callback becomes the TopCallback
+        mDispatcher.registerOnBackInvokedCallback(PRIORITY_DEFAULT, mCallback2);
+        assertCallbacksSize(/* default */ 2, /* overlay */ 0);
+        assertSetCallbackInfo();
+        assertTopCallback(mCallback2);
+
+        waitForIdle();
+        verifyNoMoreInteractions(mWindowSession);
+        verifyNoMoreInteractions(mCallback1);
+        verifyNoMoreInteractions(mCallback2);
+    }
+
+    @Test
+    public void registerCallback_differentPriority_sameCallback() throws RemoteException {
+        mDispatcher.registerOnBackInvokedCallback(PRIORITY_OVERLAY, mCallback1);
+        assertCallbacksSize(/* default */ 0, /* overlay */ 1);
+        assertSetCallbackInfo();
+        assertTopCallback(mCallback1);
+
+        // The callback is moved to the new priority list
+        mDispatcher.registerOnBackInvokedCallback(PRIORITY_DEFAULT, mCallback1);
+        assertCallbacksSize(/* default */ 1, /* overlay */ 0);
+        assertSetCallbackInfo();
+        assertTopCallback(mCallback1);
+
+        waitForIdle();
+        verifyNoMoreInteractions(mWindowSession);
+        verifyNoMoreInteractions(mCallback1);
+    }
+
+    @Test
+    public void registerCallback_differentPriority_differentCallback() throws RemoteException {
+        mDispatcher.registerOnBackInvokedCallback(PRIORITY_OVERLAY, mCallback1);
+        assertSetCallbackInfo();
+        assertCallbacksSize(/* default */ 0, /* overlay */ 1);
+        assertTopCallback(mCallback1);
+
+        // The callback with higher priority is still the TopCallback
+        mDispatcher.registerOnBackInvokedCallback(PRIORITY_DEFAULT, mCallback2);
+        assertNoSetCallbackInfo();
+        assertCallbacksSize(/* default */ 1, /* overlay */ 1);
+        assertTopCallback(mCallback1);
+
+        waitForIdle();
+        verifyNoMoreInteractions(mWindowSession);
+        verifyNoMoreInteractions(mCallback1);
+        verifyNoMoreInteractions(mCallback2);
+    }
+
+    @Test
+    public void registerCallback_sameInstanceAddedTwice() throws RemoteException {
+        mDispatcher.registerOnBackInvokedCallback(PRIORITY_OVERLAY, mCallback1);
+        assertCallbacksSize(/* default */ 0, /* overlay */ 1);
+        assertSetCallbackInfo();
+        assertTopCallback(mCallback1);
+
+        mDispatcher.registerOnBackInvokedCallback(PRIORITY_DEFAULT, mCallback2);
+        assertCallbacksSize(/* default */ 1, /* overlay */ 1);
+        assertNoSetCallbackInfo();
+        assertTopCallback(mCallback1);
+
+        mDispatcher.registerOnBackInvokedCallback(PRIORITY_DEFAULT, mCallback1);
+        assertCallbacksSize(/* default */ 2, /* overlay */ 0);
+        assertSetCallbackInfo();
+        assertTopCallback(mCallback1);
+
+        mDispatcher.registerOnBackInvokedCallback(PRIORITY_OVERLAY, mCallback2);
+        assertCallbacksSize(/* default */ 1, /* overlay */ 1);
+        assertSetCallbackInfo();
+        assertTopCallback(mCallback2);
+
+        waitForIdle();
+        verifyNoMoreInteractions(mWindowSession);
+        verifyNoMoreInteractions(mCallback1);
+        verifyNoMoreInteractions(mCallback2);
+    }
+
     @Test
     public void propagatesTopCallback_samePriority() throws RemoteException {
-        ArgumentCaptor<OnBackInvokedCallbackInfo> captor =
-                ArgumentCaptor.forClass(OnBackInvokedCallbackInfo.class);
+        mDispatcher.registerOnBackInvokedCallback(PRIORITY_DEFAULT, mCallback1);
+        OnBackInvokedCallbackInfo callbackInfo1 = assertSetCallbackInfo();
 
-        mDispatcher.registerOnBackInvokedCallback(
-                OnBackInvokedDispatcher.PRIORITY_DEFAULT, mCallback1);
-        mDispatcher.registerOnBackInvokedCallback(
-                OnBackInvokedDispatcher.PRIORITY_DEFAULT, mCallback2);
+        mDispatcher.registerOnBackInvokedCallback(PRIORITY_DEFAULT, mCallback2);
+        OnBackInvokedCallbackInfo callbackInfo2 = assertSetCallbackInfo();
 
-        verify(mWindowSession, times(2)).setOnBackInvokedCallbackInfo(
-                Mockito.eq(mWindow),
-                captor.capture());
-        captor.getAllValues().get(0).getCallback().onBackStarted(mBackEvent);
+        callbackInfo1.getCallback().onBackStarted(mBackEvent);
+
         waitForIdle();
         verify(mCallback1).onBackStarted(any(BackEvent.class));
         verifyZeroInteractions(mCallback2);
 
-        captor.getAllValues().get(1).getCallback().onBackStarted(mBackEvent);
+        callbackInfo2.getCallback().onBackStarted(mBackEvent);
+
         waitForIdle();
         verify(mCallback2).onBackStarted(any(BackEvent.class));
+
+        // Calls sequence: BackProgressAnimator.onBackStarted() -> BackProgressAnimator.reset() ->
+        // Spring.animateToFinalPosition(0). This causes a progress event to be fired.
+        verify(mCallback1, atMost(1)).onBackProgressed(any(BackEvent.class));
         verifyNoMoreInteractions(mCallback1);
     }
 
     @Test
     public void propagatesTopCallback_differentPriority() throws RemoteException {
-        ArgumentCaptor<OnBackInvokedCallbackInfo> captor =
-                ArgumentCaptor.forClass(OnBackInvokedCallbackInfo.class);
+        mDispatcher.registerOnBackInvokedCallback(PRIORITY_OVERLAY, mCallback1);
+        mDispatcher.registerOnBackInvokedCallback(PRIORITY_DEFAULT, mCallback2);
 
-        mDispatcher.registerOnBackInvokedCallback(
-                OnBackInvokedDispatcher.PRIORITY_OVERLAY, mCallback1);
-        mDispatcher.registerOnBackInvokedCallback(
-                OnBackInvokedDispatcher.PRIORITY_DEFAULT, mCallback2);
+        OnBackInvokedCallbackInfo callbackInfo = assertSetCallbackInfo();
 
-        verify(mWindowSession).setOnBackInvokedCallbackInfo(
-                Mockito.eq(mWindow), captor.capture());
         verifyNoMoreInteractions(mWindowSession);
-        assertEquals(captor.getValue().getPriority(), OnBackInvokedDispatcher.PRIORITY_OVERLAY);
-        captor.getValue().getCallback().onBackStarted(mBackEvent);
+        assertEquals(callbackInfo.getPriority(), PRIORITY_OVERLAY);
+
+        callbackInfo.getCallback().onBackStarted(mBackEvent);
+
         waitForIdle();
         verify(mCallback1).onBackStarted(any(BackEvent.class));
     }
 
     @Test
     public void propagatesTopCallback_withRemoval() throws RemoteException {
-        mDispatcher.registerOnBackInvokedCallback(
-                OnBackInvokedDispatcher.PRIORITY_DEFAULT, mCallback1);
-        mDispatcher.registerOnBackInvokedCallback(
-                OnBackInvokedDispatcher.PRIORITY_DEFAULT, mCallback2);
+        mDispatcher.registerOnBackInvokedCallback(PRIORITY_DEFAULT, mCallback1);
+        assertSetCallbackInfo();
 
-        reset(mWindowSession);
+        mDispatcher.registerOnBackInvokedCallback(PRIORITY_DEFAULT, mCallback2);
+        assertSetCallbackInfo();
+
         mDispatcher.unregisterOnBackInvokedCallback(mCallback1);
-        verifyZeroInteractions(mWindowSession);
+
+        waitForIdle();
+        verifyNoMoreInteractions(mWindowSession);
+        verifyNoMoreInteractions(mCallback1);
 
         mDispatcher.unregisterOnBackInvokedCallback(mCallback2);
+
+        waitForIdle();
         verify(mWindowSession).setOnBackInvokedCallbackInfo(Mockito.eq(mWindow), isNull());
     }
 
 
     @Test
     public void propagatesTopCallback_sameInstanceAddedTwice() throws RemoteException {
-        ArgumentCaptor<OnBackInvokedCallbackInfo> captor =
-                ArgumentCaptor.forClass(OnBackInvokedCallbackInfo.class);
+        mDispatcher.registerOnBackInvokedCallback(PRIORITY_OVERLAY, mCallback1);
+        assertSetCallbackInfo();
+        mDispatcher.registerOnBackInvokedCallback(PRIORITY_DEFAULT, mCallback2);
+        assertNoSetCallbackInfo();
+        mDispatcher.registerOnBackInvokedCallback(PRIORITY_DEFAULT, mCallback1);
+        assertSetCallbackInfo();
 
-        mDispatcher.registerOnBackInvokedCallback(OnBackInvokedDispatcher.PRIORITY_OVERLAY,
-                mCallback1
-        );
-        mDispatcher.registerOnBackInvokedCallback(
-                OnBackInvokedDispatcher.PRIORITY_DEFAULT, mCallback2);
-        mDispatcher.registerOnBackInvokedCallback(
-                OnBackInvokedDispatcher.PRIORITY_DEFAULT, mCallback1);
+        mDispatcher.registerOnBackInvokedCallback(PRIORITY_OVERLAY, mCallback2);
 
-        reset(mWindowSession);
-        mDispatcher.registerOnBackInvokedCallback(
-                OnBackInvokedDispatcher.PRIORITY_OVERLAY, mCallback2);
-        verify(mWindowSession).setOnBackInvokedCallbackInfo(Mockito.eq(mWindow), captor.capture());
-        captor.getValue().getCallback().onBackStarted(mBackEvent);
+        OnBackInvokedCallbackInfo lastCallbackInfo = assertSetCallbackInfo();
+
+        lastCallbackInfo.getCallback().onBackStarted(mBackEvent);
+
         waitForIdle();
         verify(mCallback2).onBackStarted(any(BackEvent.class));
     }
 
     @Test
     public void onUnregisterWhileBackInProgress_callOnBackCancelled() throws RemoteException {
-        ArgumentCaptor<OnBackInvokedCallbackInfo> captor =
-                ArgumentCaptor.forClass(OnBackInvokedCallbackInfo.class);
+        mDispatcher.registerOnBackInvokedCallback(PRIORITY_DEFAULT, mCallback1);
 
-        mDispatcher.registerOnBackInvokedCallback(
-                OnBackInvokedDispatcher.PRIORITY_DEFAULT, mCallback1);
+        OnBackInvokedCallbackInfo callbackInfo = assertSetCallbackInfo();
 
-        verify(mWindowSession).setOnBackInvokedCallbackInfo(
-                Mockito.eq(mWindow),
-                captor.capture());
-        IOnBackInvokedCallback iOnBackInvokedCallback = captor.getValue().getCallback();
-        iOnBackInvokedCallback.onBackStarted(mBackEvent);
+        callbackInfo.getCallback().onBackStarted(mBackEvent);
+
         waitForIdle();
         verify(mCallback1).onBackStarted(any(BackEvent.class));
 
         mDispatcher.unregisterOnBackInvokedCallback(mCallback1);
+
+        waitForIdle();
         verify(mCallback1).onBackCancelled();
-        verifyNoMoreInteractions(mCallback1);
+        verify(mWindowSession).setOnBackInvokedCallbackInfo(Mockito.eq(mWindow), isNull());
     }
 }