Unbind CS if connection is not created within 15 seconds.

This CL adds a check to ensure that connection creation occurs within 15 seconds after binding to that ConnectionService. If the connection/conference is not created in that timespan, this CL adds logic to manually unbind the ConnectionService at that point in time. This prevents malicious apps from keeping a declared permission in forever even in the background. This updated CL includes a null check to avoid a NPE that occurred in the first iteration of this change.

Bug: 293458004
Test: manually using the provided apk + atest CallsManagerTest
Flag: EXEMPT Security High/Critical Severity CVE
Change-Id: If46cfa26278f09854c10266af6eaa73382f20296
(cherry picked from commit 48d6b0df91fb1c8c0b11891f878084c6d8f9fa8a)
Merged-In: If46cfa26278f09854c10266af6eaa73382f20296
diff --git a/src/com/android/server/telecom/Call.java b/src/com/android/server/telecom/Call.java
index 60016fd..d049985 100644
--- a/src/com/android/server/telecom/Call.java
+++ b/src/com/android/server/telecom/Call.java
@@ -353,6 +353,17 @@
     /** The state of the call. */
     private int mState;
 
+    /**
+     * Determines whether the {@link ConnectionService} has responded to the initial request to
+     * create the connection.
+     *
+     * {@code false} indicates the {@link Call} has been added to Telecom, but the
+     * {@link Connection} has not yet been returned by the associated {@link ConnectionService}.
+     * {@code true} indicates the {@link Call} has an associated {@link Connection} reported by the
+     * {@link ConnectionService}.
+     */
+    private boolean mIsCreateConnectionComplete = false;
+
     /** The handle with which to establish this call. */
     private Uri mHandle;
 
@@ -1041,6 +1052,19 @@
         return mConnectionService;
     }
 
+    /**
+     * @return {@code true} if the connection has been created by the underlying
+     * {@link ConnectionService}, {@code false} otherwise.
+     */
+    public boolean isCreateConnectionComplete() {
+        return mIsCreateConnectionComplete;
+    }
+
+    @VisibleForTesting
+    public void setIsCreateConnectionComplete(boolean isCreateConnectionComplete) {
+        mIsCreateConnectionComplete = isCreateConnectionComplete;
+    }
+
     @VisibleForTesting
     public int getState() {
         return mState;
@@ -2192,6 +2216,7 @@
             CallIdMapper idMapper,
             ParcelableConference conference) {
         Log.v(this, "handleCreateConferenceSuccessful %s", conference);
+        mIsCreateConnectionComplete = true;
         setTargetPhoneAccount(conference.getPhoneAccount());
         setHandle(conference.getHandle(), conference.getHandlePresentation());
 
@@ -2225,6 +2250,7 @@
             CallIdMapper idMapper,
             ParcelableConnection connection) {
         Log.v(this, "handleCreateConnectionSuccessful %s", connection);
+        mIsCreateConnectionComplete = true;
         setTargetPhoneAccount(connection.getPhoneAccount());
         setHandle(connection.getHandle(), connection.getHandlePresentation());
         setCallerDisplayName(
diff --git a/src/com/android/server/telecom/ConnectionServiceWrapper.java b/src/com/android/server/telecom/ConnectionServiceWrapper.java
index f2fa7a5..d30216a 100755
--- a/src/com/android/server/telecom/ConnectionServiceWrapper.java
+++ b/src/com/android/server/telecom/ConnectionServiceWrapper.java
@@ -40,6 +40,7 @@
 import android.telecom.GatewayInfo;
 import android.telecom.Log;
 import android.telecom.Logging.Session;
+import android.telecom.Logging.Runnable;
 import android.telecom.ParcelableConference;
 import android.telecom.ParcelableConnection;
 import android.telecom.PhoneAccountHandle;
@@ -63,6 +64,11 @@
 import java.util.Map;
 import java.util.Set;
 import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.ScheduledExecutorService;
+import java.util.concurrent.ScheduledFuture;
+import java.util.concurrent.TimeUnit;
 import java.util.Objects;
 
 /**
@@ -76,6 +82,11 @@
         ConnectionServiceFocusManager.ConnectionServiceFocus {
 
     private static final String TELECOM_ABBREVIATION = "cast";
+    private static final long SERVICE_BINDING_TIMEOUT = 15000L;
+    private ScheduledExecutorService mScheduledExecutor =
+            Executors.newSingleThreadScheduledExecutor();
+    // Pre-allocate space for 2 calls; realistically thats all we should ever need (tm)
+    private final Map<Call, ScheduledFuture<?>> mScheduledFutureMap = new ConcurrentHashMap<>(2);
 
     private final class Adapter extends IConnectionServiceAdapter.Stub {
 
@@ -89,6 +100,12 @@
             try {
                 synchronized (mLock) {
                     logIncoming("handleCreateConnectionComplete %s", callId);
+                    Call call = mCallIdMapper.getCall(callId);
+                    if (call != null && mScheduledFutureMap.containsKey(call)) {
+                        ScheduledFuture<?> existingTimeout = mScheduledFutureMap.get(call);
+                        existingTimeout.cancel(false /* cancelIfRunning */);
+                        mScheduledFutureMap.remove(call);
+                    }
                     // Check status hints image for cross user access
                     if (connection.getStatusHints() != null) {
                         Icon icon = connection.getStatusHints().getIcon();
@@ -127,6 +144,12 @@
             try {
                 synchronized (mLock) {
                     logIncoming("handleCreateConferenceComplete %s", callId);
+                    Call call = mCallIdMapper.getCall(callId);
+                    if (call != null && mScheduledFutureMap.containsKey(call)) {
+                        ScheduledFuture<?> existingTimeout = mScheduledFutureMap.get(call);
+                        existingTimeout.cancel(false /* cancelIfRunning */);
+                        mScheduledFutureMap.remove(call);
+                    }
                     // Check status hints image for cross user access
                     if (conference.getStatusHints() != null) {
                         Icon icon = conference.getStatusHints().getIcon();
@@ -1260,7 +1283,8 @@
      * @param context The context.
      * @param userHandle The {@link UserHandle} to use when binding.
      */
-    ConnectionServiceWrapper(
+    @VisibleForTesting
+    public ConnectionServiceWrapper(
             ComponentName componentName,
             ConnectionServiceRepository connectionServiceRepository,
             PhoneAccountRegistrar phoneAccountRegistrar,
@@ -1356,7 +1380,26 @@
                         .setParticipants(call.getParticipants())
                         .setIsAdhocConferenceCall(call.isAdhocConferenceCall())
                         .build();
-
+                Runnable r = new Runnable("CSW.cC", mLock) {
+                    @Override
+                    public void loggedRun() {
+                        if (!call.isCreateConnectionComplete()) {
+                            Log.e(this, new Exception(),
+                                    "Conference %s creation timeout",
+                                    getComponentName());
+                            Log.addEvent(call, LogUtils.Events.CREATE_CONFERENCE_TIMEOUT,
+                                    Log.piiHandle(call.getHandle()) + " via:" +
+                                            getComponentName().getPackageName());
+                            response.handleCreateConferenceFailure(
+                                    new DisconnectCause(DisconnectCause.ERROR));
+                        }
+                    }
+                };
+                // Post cleanup to the executor service and cache the future, so we can cancel it if
+                // needed.
+                ScheduledFuture<?> future = mScheduledExecutor.schedule(r.getRunnableToCancel(),
+                        SERVICE_BINDING_TIMEOUT, TimeUnit.MILLISECONDS);
+                mScheduledFutureMap.put(call, future);
                 try {
                     mServiceInterface.createConference(
                             call.getConnectionManagerPhoneAccount(),
@@ -1458,7 +1501,26 @@
                         .setRttPipeFromInCall(call.getInCallToCsRttPipeForCs())
                         .setRttPipeToInCall(call.getCsToInCallRttPipeForCs())
                         .build();
-
+                Runnable r = new Runnable("CSW.cC", mLock) {
+                    @Override
+                    public void loggedRun() {
+                        if (!call.isCreateConnectionComplete()) {
+                            Log.e(this, new Exception(),
+                                    "Connection %s creation timeout",
+                                    getComponentName());
+                            Log.addEvent(call, LogUtils.Events.CREATE_CONNECTION_TIMEOUT,
+                                    Log.piiHandle(call.getHandle()) + " via:" +
+                                            getComponentName().getPackageName());
+                            response.handleCreateConnectionFailure(
+                                    new DisconnectCause(DisconnectCause.ERROR));
+                        }
+                    }
+                };
+                // Post cleanup to the executor service and cache the future, so we can cancel it if
+                // needed.
+                ScheduledFuture<?> future = mScheduledExecutor.schedule(r.getRunnableToCancel(),
+                        SERVICE_BINDING_TIMEOUT, TimeUnit.MILLISECONDS);
+                mScheduledFutureMap.put(call, future);
                 try {
                     mServiceInterface.createConnection(
                             call.getConnectionManagerPhoneAccount(),
@@ -1868,7 +1930,8 @@
         }
     }
 
-    void addCall(Call call) {
+    @VisibleForTesting
+    public void addCall(Call call) {
         if (mCallIdMapper.getCallId(call) == null) {
             mCallIdMapper.addCall(call);
         }
@@ -2084,6 +2147,13 @@
     @Override
     protected void removeServiceInterface() {
         Log.v(this, "Removing Connection Service Adapter.");
+        if (mServiceInterface == null) {
+            // In some cases, we may receive multiple calls to
+            // remoteServiceInterface, such as when the remote process crashes
+            // (onBinderDied & onServiceDisconnected)
+            Log.w(this, "removeServiceInterface: mServiceInterface is null");
+            return;
+        }
         removeConnectionServiceAdapter(mAdapter);
         // We have lost our service connection. Notify the world that this service is done.
         // We must notify the adapter before CallsManager. The adapter will force any pending
@@ -2092,6 +2162,10 @@
         handleConnectionServiceDeath();
         mCallsManager.handleConnectionServiceDeath(this);
         mServiceInterface = null;
+        if (mScheduledExecutor != null) {
+            mScheduledExecutor.shutdown();
+            mScheduledExecutor = null;
+        }
     }
 
     @Override
@@ -2210,6 +2284,7 @@
             }
         }
         mCallIdMapper.clear();
+        mScheduledFutureMap.clear();
 
         if (mConnSvrFocusListener != null) {
             mConnSvrFocusListener.onConnectionServiceDeath(this);
@@ -2335,4 +2410,9 @@
         sb.append("]");
         return sb.toString();
     }
+
+    @VisibleForTesting
+    public void setScheduledExecutorService(ScheduledExecutorService service) {
+        mScheduledExecutor = service;
+    }
 }
diff --git a/src/com/android/server/telecom/LogUtils.java b/src/com/android/server/telecom/LogUtils.java
index 12e780f..aafc383 100644
--- a/src/com/android/server/telecom/LogUtils.java
+++ b/src/com/android/server/telecom/LogUtils.java
@@ -131,8 +131,10 @@
         public static final String STOP_CALL_WAITING_TONE = "STOP_CALL_WAITING_TONE";
         public static final String START_CONNECTION = "START_CONNECTION";
         public static final String CREATE_CONNECTION_FAILED = "CREATE_CONNECTION_FAILED";
+        public static final String CREATE_CONNECTION_TIMEOUT = "CREATE_CONNECTION_TIMEOUT";
         public static final String START_CONFERENCE = "START_CONFERENCE";
         public static final String CREATE_CONFERENCE_FAILED = "CREATE_CONFERENCE_FAILED";
+        public static final String CREATE_CONFERENCE_TIMEOUT = "CREATE_CONFERENCE_TIMEOUT";
         public static final String BIND_CS = "BIND_CS";
         public static final String CS_BOUND = "CS_BOUND";
         public static final String CONFERENCE_WITH = "CONF_WITH";
diff --git a/tests/src/com/android/server/telecom/tests/BasicCallTests.java b/tests/src/com/android/server/telecom/tests/BasicCallTests.java
index 7bce5a6..f5e6f42 100644
--- a/tests/src/com/android/server/telecom/tests/BasicCallTests.java
+++ b/tests/src/com/android/server/telecom/tests/BasicCallTests.java
@@ -993,6 +993,7 @@
         call.setTargetPhoneAccount(mPhoneAccountA1.getAccountHandle());
         assert(call.isVideoCallingSupportedByPhoneAccount());
         assertEquals(VideoProfile.STATE_BIDIRECTIONAL, call.getVideoState());
+        call.setIsCreateConnectionComplete(true);
     }
 
     /**
@@ -1016,6 +1017,7 @@
         call.setTargetPhoneAccount(mPhoneAccountA2.getAccountHandle());
         assert(!call.isVideoCallingSupportedByPhoneAccount());
         assertEquals(VideoProfile.STATE_AUDIO_ONLY, call.getVideoState());
+        call.setIsCreateConnectionComplete(true);
     }
 
     /**
diff --git a/tests/src/com/android/server/telecom/tests/CallsManagerTest.java b/tests/src/com/android/server/telecom/tests/CallsManagerTest.java
index e7d0a26..673e428 100644
--- a/tests/src/com/android/server/telecom/tests/CallsManagerTest.java
+++ b/tests/src/com/android/server/telecom/tests/CallsManagerTest.java
@@ -41,6 +41,7 @@
 import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.when;
+import static java.lang.Thread.sleep;
 
 import android.content.ComponentName;
 import android.content.ContentResolver;
@@ -49,6 +50,7 @@
 import android.net.Uri;
 import android.os.Bundle;
 import android.os.Handler;
+import android.os.IBinder;
 import android.os.Looper;
 import android.os.Process;
 import android.os.SystemClock;
@@ -67,6 +69,7 @@
 import android.util.Pair;
 import android.widget.Toast;
 
+import com.android.internal.telecom.IConnectionService;
 import com.android.server.telecom.AsyncRingtonePlayer;
 import com.android.server.telecom.Call;
 import com.android.server.telecom.CallAudioManager;
@@ -80,6 +83,7 @@
 import com.android.server.telecom.ConnectionServiceFocusManager;
 import com.android.server.telecom.ConnectionServiceFocusManager.ConnectionServiceFocusManagerFactory;
 import com.android.server.telecom.ConnectionServiceWrapper;
+import com.android.server.telecom.CreateConnectionResponse;
 import com.android.server.telecom.DefaultDialerCache;
 import com.android.server.telecom.EmergencyCallHelper;
 import com.android.server.telecom.HeadsetMediaButton;
@@ -211,6 +215,7 @@
     @Mock private RoleManagerAdapter mRoleManagerAdapter;
     @Mock private ToastFactory mToastFactory;
     @Mock private Toast mToast;
+    @Mock private IConnectionService mIConnectionService;
 
     private CallsManager mCallsManager;
 
@@ -278,11 +283,17 @@
                 eq(SIM_2_HANDLE), any())).thenReturn(SIM_2_ACCOUNT);
         when(mToastFactory.makeText(any(), anyInt(), anyInt())).thenReturn(mToast);
         when(mToastFactory.makeText(any(), any(), anyInt())).thenReturn(mToast);
+        when(mIConnectionService.asBinder()).thenReturn(mock(IBinder.class));
+
+        mComponentContextFixture.addConnectionService(
+                SIM_1_ACCOUNT.getAccountHandle().getComponentName(), mIConnectionService);
     }
 
     @Override
     @After
     public void tearDown() throws Exception {
+        mComponentContextFixture.removeConnectionService(
+                SIM_1_ACCOUNT.getAccountHandle().getComponentName(), mIConnectionService);
         super.tearDown();
     }
 
@@ -1695,6 +1706,35 @@
         assertTrue(argumentCaptor.getValue().contains("Unavailable phoneAccountHandle"));
     }
 
+    @Test
+    public void testConnectionServiceCreateConnectionTimeout() throws Exception {
+        ConnectionServiceWrapper service = new ConnectionServiceWrapper(
+                SIM_1_ACCOUNT.getAccountHandle().getComponentName(), null,
+                mPhoneAccountRegistrar, mCallsManager, mContext, mLock, null);
+        TestScheduledExecutorService scheduledExecutorService = new TestScheduledExecutorService();
+        service.setScheduledExecutorService(scheduledExecutorService);
+        Call call = addSpyCall();
+        service.addCall(call);
+        when(call.isCreateConnectionComplete()).thenReturn(false);
+        CreateConnectionResponse response = mock(CreateConnectionResponse.class);
+
+        service.createConnection(call, response);
+        waitUntilConditionIsTrueOrTimeout(new Condition() {
+            @Override
+            public Object expected() {
+                return true;
+            }
+
+            @Override
+            public Object actual() {
+                return scheduledExecutorService.isRunnableScheduledAtTime(15000L);
+            }
+        }, 5000L, "Expected job failed to schedule");
+        scheduledExecutorService.advanceTime(15000L);
+        verify(response).handleCreateConnectionFailure(
+                eq(new DisconnectCause(DisconnectCause.ERROR)));
+    }
+
     private Call addSpyCall() {
         return addSpyCall(SIM_2_HANDLE, CallState.ACTIVE);
     }
@@ -1787,4 +1827,19 @@
         when(mPhoneAccountRegistrar.getSimPhoneAccountsOfCurrentUser()).thenReturn(
                 new ArrayList<>(Arrays.asList(SIM_1_HANDLE, SIM_2_HANDLE)));
     }
+
+    private void waitUntilConditionIsTrueOrTimeout(Condition condition, long timeout,
+            String description) throws InterruptedException {
+        final long start = System.currentTimeMillis();
+        while (!condition.expected().equals(condition.actual())
+                && System.currentTimeMillis() - start < timeout) {
+            sleep(50);
+        }
+        assertEquals(description, condition.expected(), condition.actual());
+    }
+
+    protected interface Condition {
+        Object expected();
+        Object actual();
+    }
 }
diff --git a/tests/src/com/android/server/telecom/tests/ComponentContextFixture.java b/tests/src/com/android/server/telecom/tests/ComponentContextFixture.java
index 477ca9f..77a182d 100644
--- a/tests/src/com/android/server/telecom/tests/ComponentContextFixture.java
+++ b/tests/src/com/android/server/telecom/tests/ComponentContextFixture.java
@@ -692,6 +692,14 @@
         mServiceInfoByComponentName.put(componentName, serviceInfo);
     }
 
+    public void removeConnectionService(
+            ComponentName componentName,
+            IConnectionService service)
+            throws Exception {
+        removeService(ConnectionService.SERVICE_INTERFACE, componentName, service);
+        mServiceInfoByComponentName.remove(componentName);
+    }
+
     public void addInCallService(
             ComponentName componentName,
             IInCallService service,
@@ -777,6 +785,12 @@
         mComponentNameByService.put(service, name);
     }
 
+    private void removeService(String action, ComponentName name, IInterface service) {
+        mComponentNamesByAction.remove(action, name);
+        mServiceByComponentName.remove(name);
+        mComponentNameByService.remove(service);
+    }
+
     private List<ResolveInfo> doQueryIntentServices(Intent intent, int flags) {
         List<ResolveInfo> result = new ArrayList<>();
         for (ComponentName componentName : mComponentNamesByAction.get(intent.getAction())) {
diff --git a/tests/src/com/android/server/telecom/tests/TestScheduledExecutorService.java b/tests/src/com/android/server/telecom/tests/TestScheduledExecutorService.java
new file mode 100644
index 0000000..8ddf42b
--- /dev/null
+++ b/tests/src/com/android/server/telecom/tests/TestScheduledExecutorService.java
@@ -0,0 +1,283 @@
+/*
+ * Copyright (C) 2022 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.server.telecom.tests;
+
+import android.util.Log;
+
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.concurrent.Callable;
+import java.util.concurrent.Delayed;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.Executors;
+import java.util.concurrent.Future;
+import java.util.concurrent.ScheduledExecutorService;
+import java.util.concurrent.ScheduledFuture;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.TimeoutException;
+
+/**
+ * A test implementation of a scheduled executor service.
+ */
+public class TestScheduledExecutorService implements ScheduledExecutorService {
+    private static final String TAG = "TestScheduledExecutorService";
+
+    private class CompletedFuture<T> implements Future<T>, ScheduledFuture<T> {
+
+        private final Callable<T> mTask;
+        private final long mDelayMs;
+        private Runnable mRunnable;
+
+        CompletedFuture(Callable<T> task) {
+            mTask = task;
+            mDelayMs = 0;
+        }
+
+        @SuppressWarnings("unused")
+        CompletedFuture(Callable<T> task, long delayMs) {
+            mTask = task;
+            mDelayMs = delayMs;
+        }
+
+        CompletedFuture(Runnable task, long delayMs) {
+            mRunnable = task;
+            mTask = (Callable<T>) Executors.callable(task);
+            mDelayMs = delayMs;
+        }
+
+        @Override
+        public boolean cancel(boolean mayInterruptIfRunning) {
+            cancelRunnable(mRunnable);
+            return true;
+        }
+
+        @Override
+        public boolean isCancelled() {
+            return false;
+        }
+
+        @Override
+        public boolean isDone() {
+            return true;
+        }
+
+        @Override
+        public T get() throws InterruptedException, ExecutionException {
+            try {
+                return mTask.call();
+            } catch (Exception e) {
+                throw new ExecutionException(e);
+            }
+        }
+
+        @Override
+        public T get(long timeout, TimeUnit unit)
+                throws InterruptedException, ExecutionException, TimeoutException {
+            try {
+                return mTask.call();
+            } catch (Exception e) {
+                throw new ExecutionException(e);
+            }
+        }
+
+        @Override
+        public long getDelay(TimeUnit unit) {
+            if (unit == TimeUnit.MILLISECONDS) {
+                return mDelayMs;
+            } else {
+                // not implemented
+                return 0;
+            }
+        }
+
+        @Override
+        public int compareTo(Delayed o) {
+            if (o == null) return 1;
+            if (o.getDelay(TimeUnit.MILLISECONDS) > mDelayMs) return -1;
+            if (o.getDelay(TimeUnit.MILLISECONDS) < mDelayMs) return 1;
+            return 0;
+        }
+    }
+
+    private long mClock = 0;
+    private Map<Long, Runnable> mScheduledRunnables = new HashMap<>();
+    private Map<Runnable, Long> mRepeatDuration = new HashMap<>();
+
+    @Override
+    public void shutdown() {
+    }
+
+    @Override
+    public List<Runnable> shutdownNow() {
+        return null;
+    }
+
+    @Override
+    public boolean isShutdown() {
+        return false;
+    }
+
+    @Override
+    public boolean isTerminated() {
+        return false;
+    }
+
+    @Override
+    public boolean awaitTermination(long timeout, TimeUnit unit) {
+        return false;
+    }
+
+    @Override
+    public <T> Future<T> submit(Callable<T> task) {
+        return new TestScheduledExecutorService.CompletedFuture<>(task);
+    }
+
+    @Override
+    public <T> Future<T> submit(Runnable task, T result) {
+        throw new UnsupportedOperationException("Not implemented");
+    }
+
+    @Override
+    public Future<?> submit(Runnable task) {
+        task.run();
+        return new TestScheduledExecutorService.CompletedFuture<>(() -> null);
+    }
+
+    @Override
+    public <T> List<Future<T>> invokeAll(Collection<? extends Callable<T>> tasks) {
+        throw new UnsupportedOperationException("Not implemented");
+    }
+
+    @Override
+    public <T> List<Future<T>> invokeAll(Collection<? extends Callable<T>> tasks, long timeout,
+                                         TimeUnit unit) {
+        throw new UnsupportedOperationException("Not implemented");
+    }
+
+    @Override
+    public <T> T invokeAny(Collection<? extends Callable<T>> tasks) {
+        throw new UnsupportedOperationException("Not implemented");
+    }
+
+    @Override
+    public <T> T invokeAny(Collection<? extends Callable<T>> tasks, long timeout, TimeUnit unit) {
+        throw new UnsupportedOperationException("Not implemented");
+    }
+
+    @Override
+    public ScheduledFuture<?> schedule(Runnable command, long delay, TimeUnit unit) {
+        // Schedule the runnable for execution at the specified time.
+        long scheduledTime = getNextExecutionTime(delay, unit);
+        mScheduledRunnables.put(scheduledTime, command);
+
+        Log.i(TAG, "schedule: runnable=" + System.identityHashCode(command) + ", time="
+                + scheduledTime);
+
+        return new TestScheduledExecutorService.CompletedFuture<Runnable>(command, delay);
+    }
+
+    @Override
+    public <V> ScheduledFuture<V> schedule(Callable<V> callable, long delay, TimeUnit unit) {
+        throw new UnsupportedOperationException("Not implemented");
+    }
+
+    @Override
+    public ScheduledFuture<?> scheduleAtFixedRate(Runnable command, long initialDelay, long period,
+                                                  TimeUnit unit) {
+        return scheduleWithFixedDelay(command, initialDelay, period, unit);
+    }
+
+    @Override
+    public ScheduledFuture<?> scheduleWithFixedDelay(Runnable command, long initialDelay,
+                                                     long delay, TimeUnit unit) {
+        // Schedule the runnable for execution at the specified time.
+        long nextScheduledTime = getNextExecutionTime(delay, unit);
+        mScheduledRunnables.put(nextScheduledTime, command);
+        mRepeatDuration.put(command, unit.toMillis(delay));
+
+        return new TestScheduledExecutorService.CompletedFuture<Runnable>(command, delay);
+    }
+
+    private long getNextExecutionTime(long delay, TimeUnit unit) {
+        long delayMillis = unit.toMillis(delay);
+        return mClock + delayMillis;
+    }
+
+    @Override
+    public void execute(Runnable command) {
+        command.run();
+    }
+
+    /**
+     * Used in unit tests, used to add a delta to the "clock" so that we can fire off scheduled
+     * items and reschedule the repeats.
+     * @param duration The duration (millis) to add to the clock.
+     */
+    public void advanceTime(long duration) {
+        Map<Long, Runnable> nextRepeats = new HashMap<>();
+        List<Runnable> toRun = new ArrayList<>();
+        mClock += duration;
+        Iterator<Map.Entry<Long, Runnable>> iterator = mScheduledRunnables.entrySet().iterator();
+        while (iterator.hasNext()) {
+            Map.Entry<Long, Runnable> entry = iterator.next();
+            if (mClock >= entry.getKey()) {
+                toRun.add(entry.getValue());
+
+                Runnable r = entry.getValue();
+                Log.i(TAG, "advanceTime: runningRunnable=" + System.identityHashCode(r));
+                // If this is a repeating scheduled item, schedule the repeat.
+                if (mRepeatDuration.containsKey(r)) {
+                    // schedule next execution
+                    nextRepeats.put(mClock + mRepeatDuration.get(r), entry.getValue());
+                }
+                iterator.remove();
+            }
+        }
+
+        // Update things at the end to avoid concurrent access.
+        mScheduledRunnables.putAll(nextRepeats);
+        toRun.forEach(r -> r.run());
+    }
+
+    /**
+     * Used from a {@link CompletedFuture} as defined above to cancel a scheduled task.
+     * @param r The runnable to cancel.
+     */
+    private void cancelRunnable(Runnable r) {
+        Optional<Map.Entry<Long, Runnable>> found = mScheduledRunnables.entrySet().stream()
+                .filter(e -> e.getValue() == r)
+                .findFirst();
+        if (found.isPresent()) {
+            mScheduledRunnables.remove(found.get().getKey());
+        }
+        mRepeatDuration.remove(r);
+        Log.i(TAG, "cancelRunnable: runnable=" + System.identityHashCode(r));
+    }
+
+    public int getNumberOfScheduledRunnables() {
+        return mScheduledRunnables.size();
+    }
+
+    public boolean isRunnableScheduledAtTime(long time) {
+        return mScheduledRunnables.containsKey(time);
+    }
+}
\ No newline at end of file