Merge "Cache callbacks that can be ready before the service is set" into main
diff --git a/flags/telecom_call_flags.aconfig b/flags/telecom_call_flags.aconfig
index 5cb9dbd..40aa8b2 100644
--- a/flags/telecom_call_flags.aconfig
+++ b/flags/telecom_call_flags.aconfig
@@ -7,4 +7,11 @@
   namespace: "telecom"
   description: "verify connection service callbacks via a transaction"
   bug: "309541257"
+}
+
+flag {
+  name: "cache_call_audio_callbacks"
+  namespace: "telecom"
+  description: "cache call audio callbacks if the service is not available and execute when set"
+  bug: "321369729"
 }
\ No newline at end of file
diff --git a/src/com/android/server/telecom/CachedAvailableEndpointsChange.java b/src/com/android/server/telecom/CachedAvailableEndpointsChange.java
new file mode 100644
index 0000000..232f00d
--- /dev/null
+++ b/src/com/android/server/telecom/CachedAvailableEndpointsChange.java
@@ -0,0 +1,82 @@
+/*
+ * Copyright (C) 2024 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.server.telecom;
+
+import android.telecom.CallEndpoint;
+
+import java.util.HashSet;
+import java.util.Objects;
+import java.util.Set;
+
+/**
+ * A {@link CallSourceService#onAvailableCallEndpointsChanged} update produced while a call had
+ * no service set. All instances share {@link #ID}, so a call keeps only the most recent one and
+ * delivers it via {@link #executeCallback} once its service is set.
+ */
+public class CachedAvailableEndpointsChange implements CachedCallback {
+    public static final String ID = CachedAvailableEndpointsChange.class.getSimpleName();
+
+    /**
+     * Private copy of the endpoints passed to the constructor, or {@code null}. Copied because
+     * callers may hand in a set they keep mutating after this update has been cached.
+     */
+    private final Set<CallEndpoint> mAvailableEndpoints;
+
+    /**
+     * @param endpoints the available endpoints to deliver later; copied so subsequent changes to
+     *                  the caller's set do not alter the cached value. May be {@code null}.
+     */
+    public CachedAvailableEndpointsChange(Set<CallEndpoint> endpoints) {
+        mAvailableEndpoints = (endpoints == null) ? null : new HashSet<>(endpoints);
+    }
+
+    /** @return the cached endpoints, or {@code null} if {@code null} was cached */
+    public Set<CallEndpoint> getAvailableEndpoints() {
+        return mAvailableEndpoints;
+    }
+
+    @Override
+    public void executeCallback(CallSourceService service, Call call) {
+        service.onAvailableCallEndpointsChanged(call, mAvailableEndpoints);
+    }
+
+    @Override
+    public String getCallbackId() {
+        return ID;
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hashCode(mAvailableEndpoints);
+    }
+
+    /**
+     * Equal when both hold equal endpoint sets (per {@link Set#equals}) or both hold
+     * {@code null}; consistent with {@link #hashCode()}.
+     */
+    @Override
+    public boolean equals(Object obj) {
+        if (this == obj) {
+            return true;
+        }
+        if (!(obj instanceof CachedAvailableEndpointsChange other)) {
+            return false;
+        }
+        return Objects.equals(mAvailableEndpoints, other.mAvailableEndpoints);
+    }
+}
+
diff --git a/src/com/android/server/telecom/CachedCallback.java b/src/com/android/server/telecom/CachedCallback.java
new file mode 100644
index 0000000..88dad07
--- /dev/null
+++ b/src/com/android/server/telecom/CachedCallback.java
@@ -0,0 +1,47 @@
+/*
+ * Copyright (C) 2024 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.server.telecom;
+
+/**
+ * A callback destined for a call's {@link CallSourceService} (e.g. ConnectionService or
+ * TransactionalService) that became ready while the call's service was still null. Such
+ * callbacks are cached on the {@link Call} and executed once the service is set.
+ *
+ * Implementations (e.g. {@link CachedMuteStateChange}) capture the value to deliver. All
+ * instances of one kind must return the same {@link #getCallbackId()} so that a newer instance
+ * replaces an older one in the call's cache.
+ */
+public interface CachedCallback {
+    /**
+     * Executes this callback, which was cached because the service was not available at the
+     * time the callback was ready.
+     *
+     * @param service the service that was just set on the call (e.g. ConnectionService)
+     * @param call    the call that had a null service when the callback was ready; its service
+     *                is now set, so the callback can be delivered
+     */
+    void executeCallback(CallSourceService service, Call call);
+
+    /**
+     * Returns the key under which this callback is cached. If a callback is produced several
+     * times while the service is not set, ONLY the last one should be sent to the client, since
+     * it is the most relevant; callbacks sharing an id therefore overwrite each other.
+     *
+     * @return the callback id used as the key of the call's callback cache
+     */
+    String getCallbackId();
+}
diff --git a/src/com/android/server/telecom/CachedCurrentEndpointChange.java b/src/com/android/server/telecom/CachedCurrentEndpointChange.java
new file mode 100644
index 0000000..0d5bac9
--- /dev/null
+++ b/src/com/android/server/telecom/CachedCurrentEndpointChange.java
@@ -0,0 +1,72 @@
+/*
+ * Copyright (C) 2024 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.server.telecom;
+
+import android.telecom.CallEndpoint;
+
+import java.util.Objects;
+
+/**
+ * A {@link CallSourceService#onCallEndpointChanged} update produced while a call had no service
+ * set. All instances share {@link #ID}, so a call keeps only the most recent one and delivers it
+ * via {@link #executeCallback} once its service is set.
+ */
+public class CachedCurrentEndpointChange implements CachedCallback {
+    public static final String ID = CachedCurrentEndpointChange.class.getSimpleName();
+
+    /** The endpoint to deliver. Not null-checked, so it may be {@code null}. */
+    private final CallEndpoint mCurrentCallEndpoint;
+
+    public CachedCurrentEndpointChange(CallEndpoint callEndpoint) {
+        mCurrentCallEndpoint = callEndpoint;
+    }
+
+    public CallEndpoint getCurrentCallEndpoint() {
+        return mCurrentCallEndpoint;
+    }
+
+    @Override
+    public void executeCallback(CallSourceService service, Call call) {
+        service.onCallEndpointChanged(call, mCurrentCallEndpoint);
+    }
+
+    @Override
+    public String getCallbackId() {
+        return ID;
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hashCode(mCurrentCallEndpoint);
+    }
+
+    /**
+     * Equal when both hold equal endpoints or both hold {@code null}. Uses
+     * {@link Objects#equals} so a {@code null} endpoint cannot cause a NullPointerException.
+     */
+    @Override
+    public boolean equals(Object obj) {
+        if (this == obj) {
+            return true;
+        }
+        if (!(obj instanceof CachedCurrentEndpointChange other)) {
+            return false;
+        }
+        return Objects.equals(mCurrentCallEndpoint, other.mCurrentCallEndpoint);
+    }
+}
+
diff --git a/src/com/android/server/telecom/CachedMuteStateChange.java b/src/com/android/server/telecom/CachedMuteStateChange.java
new file mode 100644
index 0000000..45cbfaa
--- /dev/null
+++ b/src/com/android/server/telecom/CachedMuteStateChange.java
@@ -0,0 +1,59 @@
+/*
+ * Copyright (C) 2024 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.server.telecom;
+
+/**
+ * Mute-state update that arrived before the call's {@link CallSourceService} was set. A call
+ * keeps at most one of these (they all share {@link #ID}) and replays it through
+ * {@link CallSourceService#onMuteStateChanged} once its service is set.
+ */
+public class CachedMuteStateChange implements CachedCallback {
+    public static final String ID = CachedMuteStateChange.class.getSimpleName();
+    boolean mIsMuted;
+
+    public CachedMuteStateChange(boolean isMuted) {
+        mIsMuted = isMuted;
+    }
+
+    /** @return the mute state captured when this update was cached */
+    public boolean isMuted() {
+        return mIsMuted;
+    }
+
+    /** Replays the captured mute state to {@code service} for {@code call}. */
+    @Override
+    public void executeCallback(CallSourceService service, Call call) {
+        service.onMuteStateChanged(call, mIsMuted);
+    }
+
+    @Override
+    public String getCallbackId() {
+        return ID;
+    }
+
+    @Override
+    public boolean equals(Object obj) {
+        // instanceof already yields false for null, so no separate null check is needed.
+        return obj instanceof CachedMuteStateChange other && other.mIsMuted == mIsMuted;
+    }
+
+    @Override
+    public int hashCode() {
+        return Boolean.hashCode(mIsMuted);
+    }
+}
+
diff --git a/src/com/android/server/telecom/Call.java b/src/com/android/server/telecom/Call.java
index cdf7cd9..c4882a0 100644
--- a/src/com/android/server/telecom/Call.java
+++ b/src/com/android/server/telecom/Call.java
@@ -86,6 +86,7 @@
 import java.util.Collection;
 import java.util.Collections;
 import java.util.Date;
+import java.util.HashMap;
 import java.util.LinkedList;
 import java.util.List;
 import java.util.Locale;
@@ -833,6 +834,28 @@
      */
      private CompletableFuture<Boolean> mBtIcsFuture;
 
+    /**
+     * Callbacks produced while this call had no {@link CallSourceService} set, keyed by
+     * {@link CachedCallback#getCallbackId()} so that only the latest callback of each kind is
+     * kept. Drained by {@code processCachedCallbacks} when a service is set.
+     */
+    private final Map<String, CachedCallback> mCachedServiceCallbacks = new HashMap<>();
+
+    /**
+     * Caches {@code callback} until this call's service is set, replacing any previously cached
+     * callback with the same {@link CachedCallback#getCallbackId()}.
+     */
+    public void cacheServiceCallback(CachedCallback callback) {
+        mCachedServiceCallbacks.put(callback.getCallbackId(), callback);
+    }
+
+    /**
+     * @return a read-only view of the callbacks still waiting for this call's service
+     */
+    public Map<String, CachedCallback> getCachedServiceCallbacks() {
+        return Collections.unmodifiableMap(mCachedServiceCallbacks);
+    }
+
     private FeatureFlags mFlags;
 
     /**
@@ -2001,7 +2012,42 @@
     }
 
     public void setTransactionServiceWrapper(TransactionalServiceWrapper service) {
+        Log.i(this, "setTransactionServiceWrapper: service=[%s]", service);
         mTransactionalService = service;
+        processCachedCallbacks(service);
+    }
+
+    /**
+     * Delivers every callback cached while this call had no service to {@code service}, then
+     * empties the cache. Does nothing unless the cacheCallAudioCallbacks flag is enabled.
+     *
+     * @param service the service that was just set. If {@code null}, the cache is kept so the
+     *                callbacks can still be delivered when a non-null service is set later.
+     */
+    private void processCachedCallbacks(CallSourceService service) {
+        if (!mFlags.cacheCallAudioCallbacks() || service == null) {
+            return;
+        }
+        // Snapshot and clear before executing: a callback that ends up caching a new value on
+        // this call can then neither break the iteration nor be wiped out by the clear().
+        List<CachedCallback> pending = List.copyOf(mCachedServiceCallbacks.values());
+        mCachedServiceCallbacks.clear();
+        for (CachedCallback callback : pending) {
+            callback.executeCallback(service, this);
+        }
+    }
+
+    /**
+     * @return the service backing this call: the transactional service wrapper for
+     *         transactional calls, otherwise the connection service wrapper. May be
+     *         {@code null} if that service has not been set yet.
+     */
+    public CallSourceService getService() {
+        if (isTransactionalCall()) {
+            return mTransactionalService;
+        } else {
+            return mConnectionService;
+        }
     }
 
     public TransactionalServiceWrapper getTransactionServiceWrapper() {
@@ -2408,6 +2439,11 @@
 
+    /**
+     * Sets the connection service for this call; equivalent to
+     * {@code setConnectionService(service, null)}.
+     */
     @VisibleForTesting
     public void setConnectionService(ConnectionServiceWrapper service) {
+        Log.i(this, "setConnectionService: service=[%s]", service);
         setConnectionService(service, null);
     }
 
@@ -2430,6 +2462,7 @@
         mConnectionService = service;
         mAnalytics.setCallConnectionService(service.getComponentName().flattenToShortString());
         mConnectionService.addCall(this);
+        processCachedCallbacks(service);
     }
 
     /**
diff --git a/src/com/android/server/telecom/CallEndpointController.java b/src/com/android/server/telecom/CallEndpointController.java
index 4738cd4..49c0d51 100644
--- a/src/com/android/server/telecom/CallEndpointController.java
+++ b/src/com/android/server/telecom/CallEndpointController.java
@@ -27,6 +27,7 @@
 import android.telecom.Log;
 
 import com.android.internal.annotations.VisibleForTesting;
+import com.android.server.telecom.flags.FeatureFlags;
 
 import java.util.HashMap;
 import java.util.Map;
@@ -49,6 +50,7 @@
 
     private final Context mContext;
     private final CallsManager mCallsManager;
+    private final FeatureFlags mFeatureFlags;
     private final HashMap<Integer, Integer> mRouteToTypeMap;
     private final HashMap<Integer, Integer> mTypeToRouteMap;
     private final Map<ParcelUuid, String> mBluetoothAddressMap = new HashMap<>();
@@ -57,10 +59,10 @@
     private ParcelUuid mRequestedEndpointId;
     private CompletableFuture<Integer> mPendingChangeRequest;
 
-    public CallEndpointController(Context context, CallsManager callsManager) {
+    public CallEndpointController(Context context, CallsManager callsManager, FeatureFlags flags) {
         mContext = context;
         mCallsManager = callsManager;
-
+        mFeatureFlags = flags;
         mRouteToTypeMap = new HashMap<>(5);
         mRouteToTypeMap.put(CallAudioState.ROUTE_EARPIECE, CallEndpoint.TYPE_EARPIECE);
         mRouteToTypeMap.put(CallAudioState.ROUTE_BLUETOOTH, CallEndpoint.TYPE_BLUETOOTH);
@@ -197,43 +199,75 @@
 
         Set<Call> calls = mCallsManager.getTrackedCalls();
         for (Call call : calls) {
-            if (call != null && call.getConnectionService() != null) {
-                call.getConnectionService().onCallEndpointChanged(call, mActiveCallEndpoint);
-            } else if (call != null && call.getTransactionServiceWrapper() != null) {
-                call.getTransactionServiceWrapper()
-                        .onCallEndpointChanged(call, mActiveCallEndpoint);
+            if (mFeatureFlags.cacheCallAudioCallbacks()) {
+                notifyOrCache(call, new CachedCurrentEndpointChange(mActiveCallEndpoint));
+            } else {
+                if (call != null && call.getConnectionService() != null) {
+                    call.getConnectionService().onCallEndpointChanged(call, mActiveCallEndpoint);
+                } else if (call != null && call.getTransactionServiceWrapper() != null) {
+                    call.getTransactionServiceWrapper()
+                            .onCallEndpointChanged(call, mActiveCallEndpoint);
+                }
             }
         }
     }
 
+    /**
+     * Delivers {@code callback} to the service backing {@code call} if that service is already
+     * set; otherwise caches it on the call so it is delivered once the service is set. A
+     * {@code null} call is ignored.
+     *
+     * @param call     a tracked call; may be {@code null}
+     * @param callback the audio update to deliver or cache
+     */
+    private void notifyOrCache(Call call, CachedCallback callback) {
+        if (call == null) {
+            return;
+        }
+        CallSourceService service = call.getService();
+        if (service != null) {
+            callback.executeCallback(service, call);
+        } else {
+            call.cacheServiceCallback(callback);
+        }
+    }
+
     private void notifyAvailableCallEndpointsChange() {
         mCallsManager.updateAvailableCallEndpoints(mAvailableCallEndpoints);
 
         Set<Call> calls = mCallsManager.getTrackedCalls();
         for (Call call : calls) {
-            if (call != null && call.getConnectionService() != null) {
-                call.getConnectionService().onAvailableCallEndpointsChanged(call,
-                        mAvailableCallEndpoints);
-            } else if (call != null && call.getTransactionServiceWrapper() != null) {
-                call.getTransactionServiceWrapper()
-                        .onAvailableCallEndpointsChanged(call, mAvailableCallEndpoints);
+            if (mFeatureFlags.cacheCallAudioCallbacks()) {
+                notifyOrCache(call, new CachedAvailableEndpointsChange(mAvailableCallEndpoints));
+            } else {
+                if (call != null && call.getConnectionService() != null) {
+                    call.getConnectionService().onAvailableCallEndpointsChanged(call,
+                            mAvailableCallEndpoints);
+                } else if (call != null && call.getTransactionServiceWrapper() != null) {
+                    call.getTransactionServiceWrapper().onAvailableCallEndpointsChanged(call,
+                            mAvailableCallEndpoints);
+                }
             }
         }
     }
 
     private void notifyMuteStateChange(boolean isMuted) {
         mCallsManager.updateMuteState(isMuted);
 
         Set<Call> calls = mCallsManager.getTrackedCalls();
         for (Call call : calls) {
-            if (call != null && call.getConnectionService() != null) {
-                call.getConnectionService().onMuteStateChanged(call, isMuted);
-            } else if (call != null && call.getTransactionServiceWrapper() != null) {
-                call.getTransactionServiceWrapper().onMuteStateChanged(call, isMuted);
+            if (mFeatureFlags.cacheCallAudioCallbacks()) {
+                notifyOrCache(call, new CachedMuteStateChange(isMuted));
+            } else {
+                if (call != null && call.getConnectionService() != null) {
+                    call.getConnectionService().onMuteStateChanged(call, isMuted);
+                } else if (call != null && call.getTransactionServiceWrapper() != null) {
+                    call.getTransactionServiceWrapper().onMuteStateChanged(call, isMuted);
+                }
             }
         }
     }
 
     private void createAvailableCallEndpoints(CallAudioState state) {
         Set<CallEndpoint> newAvailableEndpoints = new HashSet<>();
         Map<ParcelUuid, String> newBluetoothDevices = new HashMap<>();
diff --git a/src/com/android/server/telecom/CallSourceService.java b/src/com/android/server/telecom/CallSourceService.java
new file mode 100644
index 0000000..132118b
--- /dev/null
+++ b/src/com/android/server/telecom/CallSourceService.java
@@ -0,0 +1,42 @@
+/*
+ * Copyright (C) 2024 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.server.telecom;
+
+import android.telecom.CallEndpoint;
+
+import java.util.Set;
+
+/**
+ * Audio-related callbacks shared by the services that back an android.telecom.Call (e.g.
+ * ConnectionService, TransactionalService). Implementing this common type allows a callback
+ * that is ready before the service has been set on the {@link Call} to be cached and delivered
+ * to the service later.
+ *
+ * Clients were observed to miss important callback information (e.g. available audio
+ * endpoints) when the call's service was still null at the moment the callback was sent; this
+ * interface helps close that timing gap so clients receive every callback.
+ */
+public interface CallSourceService {
+    /** Called when the mute state that applies to {@code activeCall} changes. */
+    void onMuteStateChanged(Call activeCall, boolean isMuted);
+
+    /** Called when the current audio endpoint for {@code activeCall} changes. */
+    void onCallEndpointChanged(Call activeCall, CallEndpoint callEndpoint);
+
+    /** Called when the set of audio endpoints available to {@code activeCall} changes. */
+    void onAvailableCallEndpointsChanged(Call activeCall, Set<CallEndpoint> availableCallEndpoints);
+}
diff --git a/src/com/android/server/telecom/ConnectionServiceWrapper.java b/src/com/android/server/telecom/ConnectionServiceWrapper.java
index 53da8ff..54ad7e6 100644
--- a/src/com/android/server/telecom/ConnectionServiceWrapper.java
+++ b/src/com/android/server/telecom/ConnectionServiceWrapper.java
@@ -33,7 +33,6 @@
 import android.os.CancellationSignal;
 import android.os.IBinder;
 import android.os.ParcelFileDescriptor;
-import android.os.Process;
 import android.os.RemoteException;
 import android.os.ResultReceiver;
 import android.os.UserHandle;
@@ -66,7 +65,6 @@
 import com.android.internal.telecom.RemoteServiceCallback;
 import com.android.internal.util.Preconditions;
 import com.android.server.telecom.flags.FeatureFlags;
-import com.android.server.telecom.flags.Flags;
 
 import java.util.ArrayList;
 import java.util.Collection;
@@ -90,7 +88,7 @@
  */
 @VisibleForTesting
 public class ConnectionServiceWrapper extends ServiceBinder implements
-        ConnectionServiceFocusManager.ConnectionServiceFocus {
+        ConnectionServiceFocusManager.ConnectionServiceFocus, CallSourceService {
 
     private static final String TELECOM_ABBREVIATION = "cast";
     private CompletableFuture<Pair<Integer, Location>> mQueryLocationFuture = null;
@@ -1953,6 +1951,7 @@
 
     /** @see IConnectionService#onCallEndpointChanged(String, CallEndpoint, Session.Info) */
     @VisibleForTesting(visibility = VisibleForTesting.Visibility.PACKAGE)
+    @Override
     public void onCallEndpointChanged(Call activeCall, CallEndpoint callEndpoint) {
         final String callId = mCallIdMapper.getCallId(activeCall);
         if (callId != null && isServiceValid("onCallEndpointChanged")) {
@@ -1968,6 +1967,7 @@
 
     /** @see IConnectionService#onAvailableCallEndpointsChanged(String, List, Session.Info) */
     @VisibleForTesting(visibility = VisibleForTesting.Visibility.PACKAGE)
+    @Override
     public void onAvailableCallEndpointsChanged(Call activeCall,
             Set<CallEndpoint> availableCallEndpoints) {
         final String callId = mCallIdMapper.getCallId(activeCall);
@@ -1986,6 +1986,7 @@
 
     /** @see IConnectionService#onMuteStateChanged(String, boolean, Session.Info) */
     @VisibleForTesting(visibility = VisibleForTesting.Visibility.PACKAGE)
+    @Override
     public void onMuteStateChanged(Call activeCall, boolean isMuted) {
         final String callId = mCallIdMapper.getCallId(activeCall);
         if (callId != null && isServiceValid("onMuteStateChanged")) {
diff --git a/src/com/android/server/telecom/TelecomServiceImpl.java b/src/com/android/server/telecom/TelecomServiceImpl.java
index fe7c0ae..d8d7584 100644
--- a/src/com/android/server/telecom/TelecomServiceImpl.java
+++ b/src/com/android/server/telecom/TelecomServiceImpl.java
@@ -246,6 +246,7 @@
                                                 callEventCallback, mCallsManager, call);
 
                         call.setTransactionServiceWrapper(serviceWrapper);
+
                         if (mFeatureFlags.transactionalVideoState()) {
                             call.setTransactionalCallSupportsVideoCalling(callAttributes);
                         }
diff --git a/src/com/android/server/telecom/TelecomSystem.java b/src/com/android/server/telecom/TelecomSystem.java
index 91a34cf..6d6fe96 100644
--- a/src/com/android/server/telecom/TelecomSystem.java
+++ b/src/com/android/server/telecom/TelecomSystem.java
@@ -304,7 +304,7 @@
                 @Override
                 public CallEndpointController create(Context context, SyncRoot lock,
                         CallsManager callsManager) {
-                    return new CallEndpointController(context, callsManager);
+                    return new CallEndpointController(context, callsManager, featureFlags);
                 }
             };
 
diff --git a/src/com/android/server/telecom/TransactionalServiceWrapper.java b/src/com/android/server/telecom/TransactionalServiceWrapper.java
index d497c6a..df2f9af 100644
--- a/src/com/android/server/telecom/TransactionalServiceWrapper.java
+++ b/src/com/android/server/telecom/TransactionalServiceWrapper.java
@@ -62,7 +62,7 @@
  * on a per-client basis which is tied to a {@link PhoneAccountHandle}
  */
 public class TransactionalServiceWrapper implements
-        ConnectionServiceFocusManager.ConnectionServiceFocus {
+        ConnectionServiceFocusManager.ConnectionServiceFocus, CallSourceService {
     private static final String TAG = TransactionalServiceWrapper.class.getSimpleName();
 
     // CallControl : Client (ex. voip app) --> Telecom
@@ -552,6 +552,7 @@
         }
     }
 
+    @Override
     public void onCallEndpointChanged(Call call, CallEndpoint endpoint) {
         if (call != null) {
             try {
@@ -561,6 +562,7 @@
         }
     }
 
+    @Override
     public void onAvailableCallEndpointsChanged(Call call, Set<CallEndpoint> endpoints) {
         if (call != null) {
             try {
@@ -571,6 +573,7 @@
         }
     }
 
+    @Override
     public void onMuteStateChanged(Call call, boolean isMuted) {
         if (call != null) {
             try {
diff --git a/tests/src/com/android/server/telecom/tests/CallEndpointControllerTest.java b/tests/src/com/android/server/telecom/tests/CallEndpointControllerTest.java
index 9101a19..b8b9560 100644
--- a/tests/src/com/android/server/telecom/tests/CallEndpointControllerTest.java
+++ b/tests/src/com/android/server/telecom/tests/CallEndpointControllerTest.java
@@ -40,6 +40,7 @@
 import com.android.server.telecom.CallEndpointController;
 import com.android.server.telecom.CallsManager;
 import com.android.server.telecom.ConnectionServiceWrapper;
+import com.android.server.telecom.flags.FeatureFlags;
 
 import org.junit.Before;
 import org.junit.After;
@@ -101,7 +102,10 @@
     @Before
     public void setUp() throws Exception {
         super.setUp();
-        mCallEndpointController = new CallEndpointController(mMockContext, mCallsManager);
+        mCallEndpointController = new CallEndpointController(
+                mMockContext,
+                mCallsManager,
+                mFeatureFlags);
         doReturn(new HashSet<>(Arrays.asList(mCall))).when(mCallsManager).getTrackedCalls();
         doReturn(mConnectionService).when(mCall).getConnectionService();
         doReturn(mCallAudioManager).when(mCallsManager).getCallAudioManager();
diff --git a/tests/src/com/android/server/telecom/tests/CallTest.java b/tests/src/com/android/server/telecom/tests/CallTest.java
index e06938d..d3f220c 100644
--- a/tests/src/com/android/server/telecom/tests/CallTest.java
+++ b/tests/src/com/android/server/telecom/tests/CallTest.java
@@ -22,6 +22,7 @@
 import static org.junit.Assert.assertNull;
 import static org.junit.Assert.assertTrue;
 import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.anyBoolean;
 import static org.mockito.ArgumentMatchers.anyInt;
 import static org.mockito.ArgumentMatchers.argThat;
 import static org.mockito.ArgumentMatchers.eq;
@@ -30,6 +31,7 @@
 import static org.mockito.Mockito.never;
 import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
 
 import android.content.ComponentName;
 import android.content.Intent;
@@ -40,6 +42,7 @@
 import android.os.Bundle;
 import android.os.UserHandle;
 import android.telecom.CallAttributes;
+import android.telecom.CallEndpoint;
 import android.telecom.CallerInfo;
 import android.telecom.Connection;
 import android.telecom.DisconnectCause;
@@ -56,6 +59,9 @@
 import androidx.test.ext.junit.runners.AndroidJUnit4;
 import androidx.test.filters.SmallTest;
 
+import com.android.server.telecom.CachedAvailableEndpointsChange;
+import com.android.server.telecom.CachedCurrentEndpointChange;
+import com.android.server.telecom.CachedMuteStateChange;
 import com.android.server.telecom.Call;
 import com.android.server.telecom.CallIdMapper;
 import com.android.server.telecom.CallState;
@@ -78,6 +84,7 @@
 import org.mockito.Mockito;
 
 import java.util.Collections;
+import java.util.Set;
 
 @RunWith(AndroidJUnit4.class)
 public class CallTest extends TelecomTestCase {
@@ -137,6 +144,171 @@
         assertTrue(call.hasGoneActiveBefore());
     }
 
+    @Test
+    public void testMultipleCachedMuteStateChanges() {
+        when(mFeatureFlags.cacheCallAudioCallbacks()).thenReturn(true);
+        TransactionalServiceWrapper tsw = Mockito.mock(TransactionalServiceWrapper.class);
+        Call call = createCall("1", Call.CALL_DIRECTION_INCOMING);
+
+        assertNull(call.getTransactionServiceWrapper());
+
+        call.cacheServiceCallback(new CachedMuteStateChange(true));
+        assertEquals(1, call.getCachedServiceCallbacks().size());
+
+        call.cacheServiceCallback(new CachedMuteStateChange(false));
+        assertEquals(1, call.getCachedServiceCallbacks().size());
+
+        CachedMuteStateChange currentCacheMuteState = (CachedMuteStateChange) call
+                .getCachedServiceCallbacks()
+                .get(CachedMuteStateChange.ID);
+
+        assertFalse(currentCacheMuteState.isMuted());
+
+        call.setTransactionServiceWrapper(tsw);
+        // only the most recently cached value may reach the service
+        verify(tsw, times(1)).onMuteStateChanged(eq(call), eq(false));
+        verify(tsw, never()).onMuteStateChanged(any(), eq(true));
+        assertEquals(0, call.getCachedServiceCallbacks().size());
+    }
+
+    @Test
+    public void testMultipleCachedCurrentEndpointChanges() {
+        when(mFeatureFlags.cacheCallAudioCallbacks()).thenReturn(true);
+        TransactionalServiceWrapper tsw = Mockito.mock(TransactionalServiceWrapper.class);
+        CallEndpoint earpiece = Mockito.mock(CallEndpoint.class);
+        CallEndpoint speaker = Mockito.mock(CallEndpoint.class);
+        when(earpiece.getEndpointType()).thenReturn(CallEndpoint.TYPE_EARPIECE);
+        when(speaker.getEndpointType()).thenReturn(CallEndpoint.TYPE_SPEAKER);
+
+        Call call = createCall("1", Call.CALL_DIRECTION_INCOMING);
+
+        assertNull(call.getTransactionServiceWrapper());
+
+        call.cacheServiceCallback(new CachedCurrentEndpointChange(earpiece));
+        assertEquals(1, call.getCachedServiceCallbacks().size());
+
+        call.cacheServiceCallback(new CachedCurrentEndpointChange(speaker));
+        assertEquals(1, call.getCachedServiceCallbacks().size());
+
+        CachedCurrentEndpointChange currentEndpointChange = (CachedCurrentEndpointChange) call
+                .getCachedServiceCallbacks()
+                .get(CachedCurrentEndpointChange.ID);
+
+        assertEquals(CallEndpoint.TYPE_SPEAKER,
+                currentEndpointChange.getCurrentCallEndpoint().getEndpointType());
+
+        call.setTransactionServiceWrapper(tsw);
+        // only the most recently cached endpoint may reach the service
+        verify(tsw, times(1)).onCallEndpointChanged(eq(call), eq(speaker));
+        verify(tsw, never()).onCallEndpointChanged(any(), eq(earpiece));
+        assertEquals(0, call.getCachedServiceCallbacks().size());
+    }
+
+    @Test
+    public void testMultipleCachedAvailableEndpointChanges() {
+        when(mFeatureFlags.cacheCallAudioCallbacks()).thenReturn(true);
+        TransactionalServiceWrapper tsw = Mockito.mock(TransactionalServiceWrapper.class);
+        CallEndpoint earpiece = Mockito.mock(CallEndpoint.class);
+        CallEndpoint bluetooth = Mockito.mock(CallEndpoint.class);
+        Set<CallEndpoint> initialSet = Set.of(earpiece);
+        Set<CallEndpoint> finalSet = Set.of(earpiece, bluetooth);
+        when(earpiece.getEndpointType()).thenReturn(CallEndpoint.TYPE_EARPIECE);
+        when(bluetooth.getEndpointType()).thenReturn(CallEndpoint.TYPE_BLUETOOTH);
+
+        Call call = createCall("1", Call.CALL_DIRECTION_INCOMING);
+
+        assertNull(call.getTransactionServiceWrapper());
+
+        call.cacheServiceCallback(new CachedAvailableEndpointsChange(initialSet));
+        assertEquals(1, call.getCachedServiceCallbacks().size());
+
+        call.cacheServiceCallback(new CachedAvailableEndpointsChange(finalSet));
+        assertEquals(1, call.getCachedServiceCallbacks().size());
+
+        CachedAvailableEndpointsChange availableEndpoints = (CachedAvailableEndpointsChange) call
+                .getCachedServiceCallbacks()
+                .get(CachedAvailableEndpointsChange.ID);
+
+        assertEquals(2, availableEndpoints.getAvailableEndpoints().size());
+
+        call.setTransactionServiceWrapper(tsw);
+        // only the most recently cached endpoint set may reach the service
+        verify(tsw, times(1)).onAvailableCallEndpointsChanged(eq(call), eq(finalSet));
+        verify(tsw, never()).onAvailableCallEndpointsChanged(any(), eq(initialSet));
+        assertEquals(0, call.getCachedServiceCallbacks().size());
+    }
+
+    /**
+     * verify that if multiple types of cached callbacks are added to the call, the call executes
+     * all the callbacks once the service is set.
+     */
+    @Test
+    public void testAllCachedCallbacks() {
+        when(mFeatureFlags.cacheCallAudioCallbacks()).thenReturn(true);
+        TransactionalServiceWrapper tsw = Mockito.mock(TransactionalServiceWrapper.class);
+        CallEndpoint earpiece = Mockito.mock(CallEndpoint.class);
+        CallEndpoint bluetooth = Mockito.mock(CallEndpoint.class);
+        Set<CallEndpoint> availableEndpointsSet = Set.of(earpiece, bluetooth);
+        when(earpiece.getEndpointType()).thenReturn(CallEndpoint.TYPE_EARPIECE);
+        when(bluetooth.getEndpointType()).thenReturn(CallEndpoint.TYPE_BLUETOOTH);
+        Call call = createCall("1", Call.CALL_DIRECTION_INCOMING);
+
+        // The call should have a null service so that callbacks are cached
+        assertNull(call.getTransactionServiceWrapper());
+
+        // add cached callbacks
+        call.cacheServiceCallback(new CachedMuteStateChange(false));
+        assertEquals(1, call.getCachedServiceCallbacks().size());
+        call.cacheServiceCallback(new CachedCurrentEndpointChange(earpiece));
+        assertEquals(2, call.getCachedServiceCallbacks().size());
+        call.cacheServiceCallback(new CachedAvailableEndpointsChange(availableEndpointsSet));
+        assertEquals(3, call.getCachedServiceCallbacks().size());
+
+        // verify the cached callbacks are stored properly within the cache map and the values
+        // can be evaluated
+        CachedMuteStateChange currentCacheMuteState = (CachedMuteStateChange) call
+                .getCachedServiceCallbacks()
+                .get(CachedMuteStateChange.ID);
+        CachedCurrentEndpointChange currentEndpointChange = (CachedCurrentEndpointChange) call
+                .getCachedServiceCallbacks()
+                .get(CachedCurrentEndpointChange.ID);
+        CachedAvailableEndpointsChange availableEndpoints = (CachedAvailableEndpointsChange) call
+                .getCachedServiceCallbacks()
+                .get(CachedAvailableEndpointsChange.ID);
+        assertFalse(currentCacheMuteState.isMuted());
+        assertEquals(CallEndpoint.TYPE_EARPIECE,
+                currentEndpointChange.getCurrentCallEndpoint().getEndpointType());
+        assertEquals(2, availableEndpoints.getAvailableEndpoints().size());
+
+        // set the service to a non-null value
+        call.setTransactionServiceWrapper(tsw);
+
+        // ensure the cached callbacks were executed with the cached values
+        verify(tsw, times(1)).onMuteStateChanged(eq(call), eq(false));
+        verify(tsw, times(1)).onCallEndpointChanged(eq(call), eq(earpiece));
+        verify(tsw, times(1)).onAvailableCallEndpointsChanged(eq(call), eq(availableEndpointsSet));
+
+        // the cache map should be cleared
+        assertEquals(0, call.getCachedServiceCallbacks().size());
+    }
+
+    /**
+     * verify that cached callbacks are neither executed nor cleared while the
+     * cacheCallAudioCallbacks flag is disabled.
+     */
+    @Test
+    public void testCachedCallbacksNotProcessedWhenFlagDisabled() {
+        when(mFeatureFlags.cacheCallAudioCallbacks()).thenReturn(false);
+        TransactionalServiceWrapper tsw = Mockito.mock(TransactionalServiceWrapper.class);
+        Call call = createCall("1", Call.CALL_DIRECTION_INCOMING);
+
+        call.cacheServiceCallback(new CachedMuteStateChange(true));
+        call.setTransactionServiceWrapper(tsw);
+
+        verify(tsw, never()).onMuteStateChanged(any(), anyBoolean());
+        assertEquals(1, call.getCachedServiceCallbacks().size());
+    }
+
     /**
      * Basic tests to check which call states are considered transitory.
      */