Merge "Listen to live cancellation signal" into udc-dev
diff --git a/packages/CredentialManager/src/com/android/credentialmanager/CredentialSelectorActivity.kt b/packages/CredentialManager/src/com/android/credentialmanager/CredentialSelectorActivity.kt
index 6f5015d..24f92c0 100644
--- a/packages/CredentialManager/src/com/android/credentialmanager/CredentialSelectorActivity.kt
+++ b/packages/CredentialManager/src/com/android/credentialmanager/CredentialSelectorActivity.kt
@@ -5,7 +5,7 @@
  * you may not use this file except in compliance with the License.
  * You may obtain a copy of the License at
  *
- *      http://www.apache.org/licenses/LICENSE-2.0
+ *      http://www.apache.org/licenses/LICENSE-2.0
  *
  * Unless required by applicable law or agreed to in writing, software
  * distributed under the License is distributed on an "AS IS" BASIS,
@@ -74,7 +74,6 @@
     override fun onNewIntent(intent: Intent) {
         super.onNewIntent(intent)
         setIntent(intent)
-        Log.d(Constants.LOG_TAG, "Existing activity received new intent")
         try {
             val viewModel: CredentialSelectorViewModel by viewModels()
             val (isCancellationRequest, shouldShowCancellationUi, appDisplayName) =
diff --git a/services/credentials/java/com/android/server/credentials/ClearRequestSession.java b/services/credentials/java/com/android/server/credentials/ClearRequestSession.java
index dce7b87..bbebbf2 100644
--- a/services/credentials/java/com/android/server/credentials/ClearRequestSession.java
+++ b/services/credentials/java/com/android/server/credentials/ClearRequestSession.java
@@ -40,11 +40,13 @@
         implements ProviderSession.ProviderInternalCallback<Void> {
     private static final String TAG = "GetRequestSession";
 
-    public ClearRequestSession(Context context, int userId, int callingUid,
+    public ClearRequestSession(Context context, RequestSession.SessionLifetime sessionCallback,
+            Object lock, int userId, int callingUid,
             IClearCredentialStateCallback callback, ClearCredentialStateRequest request,
             CallingAppInfo callingAppInfo, CancellationSignal cancellationSignal,
             long startedTimestamp) {
-        super(context, userId, callingUid, request, callback, RequestInfo.TYPE_UNDEFINED,
+        super(context, sessionCallback, lock, userId, callingUid, request, callback,
+                RequestInfo.TYPE_UNDEFINED,
                 callingAppInfo, cancellationSignal, startedTimestamp);
     }
 
diff --git a/services/credentials/java/com/android/server/credentials/CreateRequestSession.java b/services/credentials/java/com/android/server/credentials/CreateRequestSession.java
index 98dc8ab..4c456a8 100644
--- a/services/credentials/java/com/android/server/credentials/CreateRequestSession.java
+++ b/services/credentials/java/com/android/server/credentials/CreateRequestSession.java
@@ -49,13 +49,15 @@
         implements ProviderSession.ProviderInternalCallback<CreateCredentialResponse> {
     private static final String TAG = "CreateRequestSession";
 
-    CreateRequestSession(@NonNull Context context, int userId, int callingUid,
+    CreateRequestSession(@NonNull Context context, RequestSession.SessionLifetime sessionCallback,
+            Object lock, int userId, int callingUid,
             CreateCredentialRequest request,
             ICreateCredentialCallback callback,
             CallingAppInfo callingAppInfo,
             CancellationSignal cancellationSignal,
             long startedTimestamp) {
-        super(context, userId, callingUid, request, callback, RequestInfo.TYPE_CREATE,
+        super(context, sessionCallback, lock, userId, callingUid, request, callback,
+                RequestInfo.TYPE_CREATE,
                 callingAppInfo, cancellationSignal, startedTimestamp);
     }
 
@@ -83,6 +85,7 @@
     @Override
     protected void launchUiWithProviderData(ArrayList<ProviderData> providerDataList) {
         mRequestSessionMetric.collectUiCallStartTime(System.nanoTime());
+        mCredentialManagerUi.setStatus(CredentialManagerUi.UiStatus.USER_INTERACTION);
         try {
             mClientCallback.onPendingIntent(mCredentialManagerUi.createPendingIntent(
                     RequestInfo.newCreateRequestInfo(
@@ -93,6 +96,7 @@
                     providerDataList));
         } catch (RemoteException e) {
             mRequestSessionMetric.collectUiReturnedFinalPhase(/*uiReturned=*/ false);
+            mCredentialManagerUi.setStatus(CredentialManagerUi.UiStatus.TERMINATED);
             respondToClientWithErrorAndFinish(
                     CreateCredentialException.TYPE_UNKNOWN,
                     "Unable to invoke selector");
diff --git a/services/credentials/java/com/android/server/credentials/CredentialManagerService.java b/services/credentials/java/com/android/server/credentials/CredentialManagerService.java
index de06d44..0aaa1d2 100644
--- a/services/credentials/java/com/android/server/credentials/CredentialManagerService.java
+++ b/services/credentials/java/com/android/server/credentials/CredentialManagerService.java
@@ -33,7 +33,6 @@
 import android.credentials.ClearCredentialStateRequest;
 import android.credentials.CreateCredentialException;
 import android.credentials.CreateCredentialRequest;
-import android.credentials.CredentialManager;
 import android.credentials.CredentialOption;
 import android.credentials.CredentialProviderInfo;
 import android.credentials.GetCredentialException;
@@ -50,6 +49,7 @@
 import android.credentials.ui.IntentFactory;
 import android.os.Binder;
 import android.os.CancellationSignal;
+import android.os.IBinder;
 import android.os.ICancellationSignal;
 import android.os.RemoteException;
 import android.os.UserHandle;
@@ -70,9 +70,11 @@
 import com.android.server.infra.SecureSettingsServiceNameResolver;
 
 import java.util.ArrayList;
+import java.util.HashMap;
 import java.util.HashSet;
 import java.util.LinkedHashSet;
 import java.util.List;
+import java.util.Map;
 import java.util.Set;
 import java.util.function.Consumer;
 import java.util.stream.Collectors;
@@ -102,6 +104,13 @@
     private final SparseArray<List<CredentialManagerServiceImpl>> mSystemServicesCacheList =
             new SparseArray<>();
 
+    /** Cache of all ongoing request sessions per user id. */
+    @GuardedBy("mLock")
+    private final SparseArray<Map<IBinder, RequestSession>> mRequestSessions =
+            new SparseArray<>();
+
+    private final SessionManager mSessionManager = new SessionManager();
+
     public CredentialManagerService(@NonNull Context context) {
         super(
                 context,
@@ -331,7 +340,7 @@
 
     @NonNull
     private Set<Pair<CredentialOption, CredentialDescriptionRegistry.FilterResult>>
-            getFilteredResultFromRegistry(List<CredentialOption> options) {
+    getFilteredResultFromRegistry(List<CredentialOption> options) {
         // Session for active/provisioned credential descriptions;
         CredentialDescriptionRegistry registry =
                 CredentialDescriptionRegistry.forUser(UserHandle.getCallingUserId());
@@ -389,14 +398,6 @@
         return providerSessions;
     }
 
-    private List<CredentialProviderInfo> getServicesForCredentialDescription(int userId) {
-        return CredentialProviderInfoFactory.getCredentialProviderServices(
-                mContext,
-                userId,
-                CredentialManager.PROVIDER_FILTER_ALL_PROVIDERS,
-                new HashSet<>());
-    }
-
     @Override
     @GuardedBy("CredentialDescriptionRegistry.sLock")
     public void onUserStopped(@NonNull TargetUser user) {
@@ -448,6 +449,8 @@
             final GetRequestSession session =
                     new GetRequestSession(
                             getContext(),
+                            mSessionManager,
+                            mLock,
                             userId,
                             callingUid,
                             callback,
@@ -455,6 +458,7 @@
                             constructCallingAppInfo(callingPackage, userId, request.getOrigin()),
                             CancellationSignal.fromTransport(cancelTransport),
                             timestampBegan);
+            addSessionLocked(userId, session);
 
             List<ProviderSession> providerSessions =
                     prepareProviderSessions(request, session);
@@ -499,6 +503,8 @@
             final PrepareGetRequestSession session =
                     new PrepareGetRequestSession(
                             getContext(),
+                            mSessionManager,
+                            mLock,
                             userId,
                             callingUid,
                             getCredentialCallback,
@@ -515,8 +521,8 @@
                     // TODO: fix
                     prepareGetCredentialCallback.onResponse(
                             new PrepareGetCredentialResponseInternal(
-                            false, null,
-                            false, false, null));
+                                    false, null,
+                                    false, false, null));
                 } catch (RemoteException e) {
                     Log.i(
                             TAG,
@@ -540,10 +546,10 @@
                 List<CredentialOption> optionsThatRequireActiveCredentials =
                         request.getCredentialOptions().stream()
                                 .filter(credentialOption -> credentialOption
-                                                .getCredentialRetrievalData()
-                                                .getStringArrayList(
-                                                        CredentialOption
-                                                                .SUPPORTED_ELEMENT_KEYS) != null)
+                                        .getCredentialRetrievalData()
+                                        .getStringArrayList(
+                                                CredentialOption
+                                                        .SUPPORTED_ELEMENT_KEYS) != null)
                                 .toList();
 
                 List<CredentialOption> optionsThatDoNotRequireActiveCredentials =
@@ -614,6 +620,8 @@
             final CreateRequestSession session =
                     new CreateRequestSession(
                             getContext(),
+                            mSessionManager,
+                            mLock,
                             userId,
                             callingUid,
                             request,
@@ -621,6 +629,7 @@
                             constructCallingAppInfo(callingPackage, userId, request.getOrigin()),
                             CancellationSignal.fromTransport(cancelTransport),
                             timestampBegan);
+            addSessionLocked(userId, session);
 
             processCreateCredential(request, callback, session);
             return cancelTransport;
@@ -815,6 +824,8 @@
             final ClearRequestSession session =
                     new ClearRequestSession(
                             getContext(),
+                            mSessionManager,
+                            mLock,
                             userId,
                             callingUid,
                             callback,
@@ -822,6 +833,7 @@
                             constructCallingAppInfo(callingPackage, userId, null),
                             CancellationSignal.fromTransport(cancelTransport),
                             timestampBegan);
+            addSessionLocked(userId, session);
 
             // Initiate all provider sessions
             // TODO: Determine if provider needs to have clear capability in their manifest
@@ -905,6 +917,13 @@
         }
     }
 
+    /** Registers the session with mSessionManager; acquires mLock itself despite the suffix. */
+    private void addSessionLocked(@UserIdInt int userId, RequestSession requestSession) {
+        synchronized (mLock) {
+            mSessionManager.addSession(userId, requestSession.mRequestId, requestSession);
+        }
+    }
+
     private void enforceCallingPackage(String callingPackage, int callingUid) {
         int packageUid;
         PackageManager pm = mContext.createContextAsUser(
@@ -919,4 +938,34 @@
             throw new SecurityException(callingPackage + " does not belong to uid " + callingUid);
         }
     }
+
+    /**
+     * Tracks ongoing request sessions in {@code mRequestSessions}, keyed first by user id and
+     * then by request token. Both methods are annotated as guarded by {@code mLock}.
+     */
+    private class SessionManager implements RequestSession.SessionLifetime {
+        /** Removes the session stored under {@code token} for {@code userId}, if present. */
+        @Override
+        @GuardedBy("mLock")
+        public void onFinishRequestSession(@UserIdInt int userId, IBinder token) {
+            Log.i(TAG, "In onFinishRequestSession");
+            final Map<IBinder, RequestSession> userSessions = mRequestSessions.get(userId);
+            if (userSessions != null) {
+                userSessions.remove(token);
+            }
+        }
+
+        /** Stores {@code requestSession} under {@code token} for {@code userId}. */
+        @GuardedBy("mLock")
+        public void addSession(@UserIdInt int userId, IBinder token,
+                RequestSession requestSession) {
+            Map<IBinder, RequestSession> userSessions = mRequestSessions.get(userId);
+            if (userSessions == null) {
+                // First session for this user: create the per-user map lazily.
+                userSessions = new HashMap<>();
+                mRequestSessions.put(userId, userSessions);
+            }
+            userSessions.put(token, requestSession);
+        }
+    }
 }
diff --git a/services/credentials/java/com/android/server/credentials/CredentialManagerUi.java b/services/credentials/java/com/android/server/credentials/CredentialManagerUi.java
index 546c37f..e16d48e 100644
--- a/services/credentials/java/com/android/server/credentials/CredentialManagerUi.java
+++ b/services/credentials/java/com/android/server/credentials/CredentialManagerUi.java
@@ -30,6 +30,7 @@
 import android.credentials.ui.UserSelectionDialogResult;
 import android.os.Bundle;
 import android.os.Handler;
+import android.os.IBinder;
 import android.os.Looper;
 import android.os.ResultReceiver;
 import android.service.credentials.CredentialProviderInfoFactory;
@@ -50,6 +51,20 @@
     @NonNull private final Context mContext;
     // TODO : Use for starting the activity for this user
     private final int mUserId;
+
+    private UiStatus mStatus;
+
+    /** Creates the intent that is to be invoked to cancel an in-progress UI session. */
+    public Intent createCancelIntent(IBinder requestId, String packageName) {
+        return IntentFactory.createCancelUiIntent(requestId, /*shouldShowCancellationUi=*/ true,
+                packageName);
+    }
+
+    enum UiStatus {
+        IN_PROGRESS,
+        USER_INTERACTION,
+        NOT_STARTED, TERMINATED
+    }
     @NonNull private final ResultReceiver mResultReceiver = new ResultReceiver(
             new Handler(Looper.getMainLooper())) {
         @Override
@@ -61,6 +76,7 @@
     private void handleUiResult(int resultCode, Bundle resultData) {
         switch (resultCode) {
             case UserSelectionDialogResult.RESULT_CODE_DIALOG_COMPLETE_WITH_SELECTION:
+                mStatus = UiStatus.IN_PROGRESS;
                 UserSelectionDialogResult selection = UserSelectionDialogResult
                         .fromResultData(resultData);
                 if (selection != null) {
@@ -70,16 +86,20 @@
                 }
                 break;
             case UserSelectionDialogResult.RESULT_CODE_DIALOG_USER_CANCELED:
+                mStatus = UiStatus.TERMINATED;
                 mCallbacks.onUiCancellation(/* isUserCancellation= */ true);
                 break;
             case UserSelectionDialogResult.RESULT_CODE_CANCELED_AND_LAUNCHED_SETTINGS:
+                mStatus = UiStatus.TERMINATED;
                 mCallbacks.onUiCancellation(/* isUserCancellation= */ false);
                 break;
             case UserSelectionDialogResult.RESULT_CODE_DATA_PARSING_FAILURE:
+                mStatus = UiStatus.TERMINATED;
                 mCallbacks.onUiSelectorInvocationFailure();
                 break;
             default:
                 Slog.i(TAG, "Unknown error code returned from the UI");
+                mStatus = UiStatus.IN_PROGRESS;
                 mCallbacks.onUiSelectorInvocationFailure();
                 break;
         }
@@ -103,6 +123,17 @@
         mContext = context;
         mUserId = userId;
         mCallbacks = callbacks;
+        mStatus = UiStatus.IN_PROGRESS;
+    }
+
+    /** Sets the status of the credential manager UI. */
+    public void setStatus(UiStatus status) {
+        mStatus = status;
+    }
+
+    /** Returns the current status of the credential manager UI. */
+    public UiStatus getStatus() {
+        return mStatus;
     }
 
     /**
diff --git a/services/credentials/java/com/android/server/credentials/GetRequestSession.java b/services/credentials/java/com/android/server/credentials/GetRequestSession.java
index c0c7be9..2548bd8 100644
--- a/services/credentials/java/com/android/server/credentials/GetRequestSession.java
+++ b/services/credentials/java/com/android/server/credentials/GetRequestSession.java
@@ -45,12 +45,13 @@
         IGetCredentialCallback, GetCredentialResponse>
         implements ProviderSession.ProviderInternalCallback<GetCredentialResponse> {
     private static final String TAG = "GetRequestSession";
-    public GetRequestSession(Context context, int userId, int callingUid,
+    public GetRequestSession(Context context, RequestSession.SessionLifetime sessionCallback,
+            Object lock, int userId, int callingUid,
             IGetCredentialCallback callback, GetCredentialRequest request,
             CallingAppInfo callingAppInfo, CancellationSignal cancellationSignal,
             long startedTimestamp) {
-        super(context, userId, callingUid, request, callback, RequestInfo.TYPE_GET,
-                callingAppInfo, cancellationSignal, startedTimestamp);
+        super(context, sessionCallback, lock, userId, callingUid, request, callback,
+                RequestInfo.TYPE_GET, callingAppInfo, cancellationSignal, startedTimestamp);
         int numTypes = (request.getCredentialOptions().stream()
                 .map(CredentialOption::getType).collect(
                 Collectors.toSet())).size(); // Dedupe type strings
@@ -81,6 +82,7 @@
     @Override
     protected void launchUiWithProviderData(ArrayList<ProviderData> providerDataList) {
         mRequestSessionMetric.collectUiCallStartTime(System.nanoTime());
+        mCredentialManagerUi.setStatus(CredentialManagerUi.UiStatus.USER_INTERACTION);
         try {
             mClientCallback.onPendingIntent(mCredentialManagerUi.createPendingIntent(
                     RequestInfo.newGetRequestInfo(
@@ -88,6 +90,7 @@
                     providerDataList));
         } catch (RemoteException e) {
             mRequestSessionMetric.collectUiReturnedFinalPhase(/*uiReturned=*/ false);
+            mCredentialManagerUi.setStatus(CredentialManagerUi.UiStatus.TERMINATED);
             respondToClientWithErrorAndFinish(
                     GetCredentialException.TYPE_UNKNOWN, "Unable to instantiate selector");
         }
diff --git a/services/credentials/java/com/android/server/credentials/PrepareGetRequestSession.java b/services/credentials/java/com/android/server/credentials/PrepareGetRequestSession.java
index c4e480a..88f3e6c 100644
--- a/services/credentials/java/com/android/server/credentials/PrepareGetRequestSession.java
+++ b/services/credentials/java/com/android/server/credentials/PrepareGetRequestSession.java
@@ -49,14 +49,13 @@
 
     private final IPrepareGetCredentialCallback mPrepareGetCredentialCallback;
 
-    public PrepareGetRequestSession(Context context, int userId, int callingUid,
-            IGetCredentialCallback callback,
-            GetCredentialRequest request,
-            CallingAppInfo callingAppInfo,
-            CancellationSignal cancellationSignal, long startedTimestamp,
-            IPrepareGetCredentialCallback prepareGetCredentialCallback) {
-        super(context, userId, callingUid, callback, request, callingAppInfo, cancellationSignal,
-                startedTimestamp);
+    public PrepareGetRequestSession(Context context,
+            RequestSession.SessionLifetime sessionCallback, Object lock, int userId,
+            int callingUid, IGetCredentialCallback getCredCallback, GetCredentialRequest request,
+            CallingAppInfo callingAppInfo, CancellationSignal cancellationSignal,
+            long startedTimestamp, IPrepareGetCredentialCallback prepareGetCredentialCallback) {
+        super(context, sessionCallback, lock, userId, callingUid, getCredCallback, request,
+                callingAppInfo, cancellationSignal, startedTimestamp);
         int numTypes = (request.getCredentialOptions().stream()
                 .map(CredentialOption::getType).collect(
                         Collectors.toSet())).size(); // Dedupe type strings
diff --git a/services/credentials/java/com/android/server/credentials/ProviderClearSession.java b/services/credentials/java/com/android/server/credentials/ProviderClearSession.java
index 1b736e0..eaf58f1 100644
--- a/services/credentials/java/com/android/server/credentials/ProviderClearSession.java
+++ b/services/credentials/java/com/android/server/credentials/ProviderClearSession.java
@@ -126,7 +126,8 @@
     protected void invokeSession() {
         if (mRemoteCredentialService != null) {
             startCandidateMetrics();
-            mRemoteCredentialService.onClearCredentialState(mProviderRequest, this);
+            mProviderCancellationSignal =
+                    mRemoteCredentialService.onClearCredentialState(mProviderRequest, this);
         }
     }
 }
diff --git a/services/credentials/java/com/android/server/credentials/ProviderCreateSession.java b/services/credentials/java/com/android/server/credentials/ProviderCreateSession.java
index bef045f..c657e3b 100644
--- a/services/credentials/java/com/android/server/credentials/ProviderCreateSession.java
+++ b/services/credentials/java/com/android/server/credentials/ProviderCreateSession.java
@@ -236,7 +236,8 @@
     protected void invokeSession() {
         if (mRemoteCredentialService != null) {
             startCandidateMetrics();
-            mRemoteCredentialService.onCreateCredential(mProviderRequest, this);
+            mProviderCancellationSignal =
+                    mRemoteCredentialService.onCreateCredential(mProviderRequest, this);
         }
     }
 
diff --git a/services/credentials/java/com/android/server/credentials/ProviderGetSession.java b/services/credentials/java/com/android/server/credentials/ProviderGetSession.java
index 427a894..9c9c0c2 100644
--- a/services/credentials/java/com/android/server/credentials/ProviderGetSession.java
+++ b/services/credentials/java/com/android/server/credentials/ProviderGetSession.java
@@ -302,7 +302,9 @@
     protected void invokeSession() {
         if (mRemoteCredentialService != null) {
             startCandidateMetrics();
-            mRemoteCredentialService.onBeginGetCredential(mProviderRequest, this);
+            // Retain the cancellation handle returned by the remote provider call.
+            mProviderCancellationSignal =
+                    mRemoteCredentialService.onBeginGetCredential(mProviderRequest, this);
         }
     }
 
diff --git a/services/credentials/java/com/android/server/credentials/ProviderSession.java b/services/credentials/java/com/android/server/credentials/ProviderSession.java
index 8c0e1c1..d165756 100644
--- a/services/credentials/java/com/android/server/credentials/ProviderSession.java
+++ b/services/credentials/java/com/android/server/credentials/ProviderSession.java
@@ -30,6 +30,7 @@
 import android.os.ICancellationSignal;
 import android.os.RemoteException;
 import android.util.Log;
+import android.util.Slog;
 
 import com.android.server.credentials.metrics.ProviderSessionMetric;
 
@@ -189,7 +190,7 @@
             }
             setStatus(Status.CANCELED);
         } catch (RemoteException e) {
-            Log.i(TAG, "Issue while cancelling provider session: " + e.getMessage());
+            Slog.e(TAG, "Issue while cancelling provider session: ", e);
         }
     }
 
diff --git a/services/credentials/java/com/android/server/credentials/RequestSession.java b/services/credentials/java/com/android/server/credentials/RequestSession.java
index 04c4bc4..ed175ed 100644
--- a/services/credentials/java/com/android/server/credentials/RequestSession.java
+++ b/services/credentials/java/com/android/server/credentials/RequestSession.java
@@ -20,6 +20,7 @@
 import android.annotation.UserIdInt;
 import android.content.ComponentName;
 import android.content.Context;
+import android.content.Intent;
 import android.credentials.CredentialProviderInfo;
 import android.credentials.ui.ProviderData;
 import android.credentials.ui.UserSelectionDialogResult;
@@ -29,8 +30,10 @@
 import android.os.IBinder;
 import android.os.Looper;
 import android.os.RemoteException;
+import android.os.UserHandle;
 import android.service.credentials.CallingAppInfo;
 import android.util.Log;
+import android.util.Slog;
 
 import com.android.internal.R;
 import com.android.server.credentials.metrics.ApiName;
@@ -39,8 +42,8 @@
 import com.android.server.credentials.metrics.RequestSessionMetric;
 
 import java.util.ArrayList;
-import java.util.HashMap;
 import java.util.Map;
+import java.util.concurrent.ConcurrentHashMap;
 
 /**
  * Base class of a request session, that listens to UI events. This class must be extended
@@ -49,6 +52,11 @@
 abstract class RequestSession<T, U, V> implements CredentialManagerUi.CredentialManagerUiCallback {
     private static final String TAG = "RequestSession";
 
+    public interface SessionLifetime {
+        /** Called when the request session finishes, identified by its user id and token. */
+        void onFinishRequestSession(@UserIdInt int userId, IBinder token);
+    }
+
     // TODO: Revise access levels of attributes
     @NonNull
     protected final T mClientRequest;
@@ -72,10 +80,14 @@
     @NonNull
     protected final CancellationSignal mCancellationSignal;
 
-    protected final Map<String, ProviderSession> mProviders = new HashMap<>();
+    protected final Map<String, ProviderSession> mProviders = new ConcurrentHashMap<>();
     protected final RequestSessionMetric mRequestSessionMetric = new RequestSessionMetric();
     protected final String mHybridService;
 
+    protected final Object mLock;
+
+    protected final SessionLifetime mSessionCallback;
+
     @NonNull
     protected RequestSessionStatus mRequestSessionStatus =
             RequestSessionStatus.IN_PROGRESS;
@@ -91,11 +103,15 @@
     }
 
     protected RequestSession(@NonNull Context context,
-            @UserIdInt int userId, int callingUid, @NonNull T clientRequest, U clientCallback,
+            RequestSession.SessionLifetime sessionCallback,
+            Object lock, @UserIdInt int userId, int callingUid,
+            @NonNull T clientRequest, U clientCallback,
             @NonNull String requestType,
             CallingAppInfo callingAppInfo,
             CancellationSignal cancellationSignal, long timestampStarted) {
         mContext = context;
+        mLock = lock;
+        mSessionCallback = sessionCallback;
         mUserId = userId;
         mCallingUid = callingUid;
         mClientRequest = clientRequest;
@@ -111,6 +127,32 @@
                 R.string.config_defaultCredentialManagerHybridService);
         mRequestSessionMetric.collectInitialPhaseMetricInfo(timestampStarted, mRequestId,
                 mCallingUid, ApiName.getMetricCodeFromRequestInfo(mRequestType));
+        setCancellationListener();
+    }
+
+    /** Registers a listener that finishes this session when mCancellationSignal is cancelled. */
+    private void setCancellationListener() {
+        mCancellationSignal.setOnCancelListener(() -> {
+            final boolean cancelUiLaunched = maybeCancelUi();
+            // Providers are cancelled only when no cancel UI was launched.
+            finishSession(/* propagateCancellation= */ !cancelUiLaunched);
+        });
+    }
+
+    /** Launches the cancel UI if status is USER_INTERACTION; returns true when launched. */
+    private boolean maybeCancelUi() {
+        if (mCredentialManagerUi.getStatus() != CredentialManagerUi.UiStatus.USER_INTERACTION) {
+            return false;
+        }
+        final long identityToken = Binder.clearCallingIdentity();
+        try {
+            mContext.startActivityAsUser(mCredentialManagerUi.createCancelIntent(
+                            mRequestId, mClientAppInfo.getPackageName())
+                    .addFlags(Intent.FLAG_ACTIVITY_NEW_TASK), UserHandle.of(mUserId));
+            return true;
+        } finally {
+            Binder.restoreCallingIdentity(identityToken);
+        }
     }
 
     public abstract ProviderSession initiateProviderSession(CredentialProviderInfo providerInfo,
@@ -154,12 +196,19 @@
     }
 
     protected void finishSession(boolean propagateCancellation) {
-        Log.i(TAG, "finishing session");
+        Slog.d(TAG, "finishing session with propagateCancellation " + propagateCancellation);
         if (propagateCancellation) {
             mProviders.values().forEach(ProviderSession::cancelProviderRemoteSession);
         }
         mRequestSessionStatus = RequestSessionStatus.COMPLETE;
         mProviders.clear();
+        clearRequestSessionLocked();
+    }
+
+    private void clearRequestSessionLocked() {
+        synchronized (mLock) {
+            mSessionCallback.onFinishRequestSession(mUserId, mRequestId);
+        }
     }
 
     protected boolean isAnyProviderPending() {