Merge "Add callback registration in ConnectivityService."
diff --git a/core/java/android/net/ConnectivityDiagnosticsManager.java b/core/java/android/net/ConnectivityDiagnosticsManager.java
index b13e4b7..d018cbd 100644
--- a/core/java/android/net/ConnectivityDiagnosticsManager.java
+++ b/core/java/android/net/ConnectivityDiagnosticsManager.java
@@ -25,13 +25,16 @@
 import android.os.Parcel;
 import android.os.Parcelable;
 import android.os.PersistableBundle;
+import android.os.RemoteException;
 
 import com.android.internal.annotations.VisibleForTesting;
 import com.android.internal.util.Preconditions;
 
 import java.lang.annotation.Retention;
 import java.lang.annotation.RetentionPolicy;
+import java.util.Map;
 import java.util.Objects;
+import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.Executor;
 
 /**
@@ -57,6 +60,11 @@
  * </ul>
  */
 public class ConnectivityDiagnosticsManager {
+    /** @hide */
+    @VisibleForTesting
+    public static final Map<ConnectivityDiagnosticsCallback, ConnectivityDiagnosticsBinder>
+            sCallbacks = new ConcurrentHashMap<>();
+
     private final Context mContext;
     private final IConnectivityManager mService;
 
@@ -646,9 +654,9 @@
      * <p>If a registering app loses its relevant permissions, any callbacks it registered will
      * silently stop receiving callbacks.
      *
-     * <p>Each register() call <b>MUST</b> use a unique ConnectivityDiagnosticsCallback instance. If
-     * a single instance is registered with multiple NetworkRequests, an IllegalArgumentException
-     * will be thrown.
+     * <p>Each register() call <b>MUST</b> use a ConnectivityDiagnosticsCallback instance that is
+     * not currently registered. If a ConnectivityDiagnosticsCallback instance is registered with
+     * multiple NetworkRequests, an IllegalArgumentException will be thrown.
      *
      * @param request The NetworkRequest that will be used to match with Networks for which
      *     callbacks will be fired
@@ -657,15 +665,21 @@
      *     System
      * @throws IllegalArgumentException if the same callback instance is registered with multiple
      *     NetworkRequests
-     * @throws SecurityException if the caller does not have appropriate permissions to register a
-     *     callback
      */
     public void registerConnectivityDiagnosticsCallback(
             @NonNull NetworkRequest request,
             @NonNull Executor e,
             @NonNull ConnectivityDiagnosticsCallback callback) {
-        // TODO(b/143187964): implement ConnectivityDiagnostics functionality
-        throw new UnsupportedOperationException("registerCallback() not supported yet");
+        final ConnectivityDiagnosticsBinder binder = new ConnectivityDiagnosticsBinder(callback, e);
+        if (sCallbacks.putIfAbsent(callback, binder) != null) {
+            throw new IllegalArgumentException("Callback is currently registered");
+        }
+
+        try {
+            mService.registerConnectivityDiagnosticsCallback(binder, request);
+        } catch (RemoteException exception) {
+            exception.rethrowFromSystemServer();
+        }
     }
 
     /**
@@ -678,7 +692,15 @@
      */
     public void unregisterConnectivityDiagnosticsCallback(
             @NonNull ConnectivityDiagnosticsCallback callback) {
-        // TODO(b/143187964): implement ConnectivityDiagnostics functionality
-        throw new UnsupportedOperationException("registerCallback() not supported yet");
+        // Remove from sCallbacks up front: remove() is atomic, so only one of several concurrent
+        // unregister calls obtains the binder and goes on to call the service.
+        final ConnectivityDiagnosticsBinder binder = sCallbacks.remove(callback);
+        if (binder == null) return;
+
+        try {
+            mService.unregisterConnectivityDiagnosticsCallback(binder);
+        } catch (RemoteException exception) {
+            exception.rethrowFromSystemServer();
+        }
     }
 }
diff --git a/services/core/java/com/android/server/ConnectivityService.java b/services/core/java/com/android/server/ConnectivityService.java
index e98c370..06f8fb1 100644
--- a/services/core/java/com/android/server/ConnectivityService.java
+++ b/services/core/java/com/android/server/ConnectivityService.java
@@ -559,13 +559,17 @@
                 .asInterface(ServiceManager.getService("dnsresolver"));
     }
 
-    /** Handler thread used for both of the handlers below. */
+    /** Handler thread used for all of the handlers below. */
     @VisibleForTesting
     protected final HandlerThread mHandlerThread;
     /** Handler used for internal events. */
     final private InternalHandler mHandler;
     /** Handler used for incoming {@link NetworkStateTracker} events. */
     final private NetworkStateTrackerHandler mTrackerHandler;
+    /** Handler used for processing {@link android.net.ConnectivityDiagnosticsManager} events */
+    @VisibleForTesting
+    final ConnectivityDiagnosticsHandler mConnectivityDiagnosticsHandler;
+
     private final DnsManager mDnsManager;
 
     private boolean mSystemReady;
@@ -632,6 +636,10 @@
     @VisibleForTesting
     final MultipathPolicyTracker mMultipathPolicyTracker;
 
+    @VisibleForTesting
+    final Map<IConnectivityDiagnosticsCallback, ConnectivityDiagnosticsCallbackInfo>
+            mConnectivityDiagnosticsCallbacks = new HashMap<>();
+
     /**
      * Implements support for the legacy "one network per network type" model.
      *
@@ -964,6 +972,8 @@
         mHandlerThread.start();
         mHandler = new InternalHandler(mHandlerThread.getLooper());
         mTrackerHandler = new NetworkStateTrackerHandler(mHandlerThread.getLooper());
+        mConnectivityDiagnosticsHandler =
+                new ConnectivityDiagnosticsHandler(mHandlerThread.getLooper());
 
         mReleasePendingIntentDelayMs = Settings.Secure.getInt(context.getContentResolver(),
                 Settings.Secure.CONNECTIVITY_RELEASE_PENDING_INTENT_DELAY_MS, 5_000);
@@ -3384,18 +3394,7 @@
         nri.unlinkDeathRecipient();
         mNetworkRequests.remove(nri.request);
 
-        synchronized (mUidToNetworkRequestCount) {
-            int requests = mUidToNetworkRequestCount.get(nri.mUid, 0);
-            if (requests < 1) {
-                Slog.wtf(TAG, "BUG: too small request count " + requests + " for UID " +
-                        nri.mUid);
-            } else if (requests == 1) {
-                mUidToNetworkRequestCount.removeAt(
-                        mUidToNetworkRequestCount.indexOfKey(nri.mUid));
-            } else {
-                mUidToNetworkRequestCount.put(nri.mUid, requests - 1);
-            }
-        }
+        decrementNetworkRequestPerUidCount(nri);
 
         mNetworkRequestInfoLogs.log("RELEASE " + nri);
         if (nri.request.isRequest()) {
@@ -3466,6 +3465,19 @@
         }
     }
 
+    private void decrementNetworkRequestPerUidCount(final NetworkRequestInfo nri) {
+        synchronized (mUidToNetworkRequestCount) {
+            final int count = mUidToNetworkRequestCount.get(nri.mUid, 0);
+            if (count > 1) {
+                mUidToNetworkRequestCount.put(nri.mUid, count - 1);
+            } else if (count == 1) { // last request: remove the entry rather than store zero
+                mUidToNetworkRequestCount.removeAt(mUidToNetworkRequestCount.indexOfKey(nri.mUid));
+            } else {
+                Slog.wtf(TAG, "BUG: too small request count " + count + " for UID " + nri.mUid);
+            }
+        }
+    }
+
     @Override
     public void setAcceptUnvalidated(Network network, boolean accept, boolean always) {
         enforceNetworkStackSettingsOrSetup();
@@ -5084,6 +5096,10 @@
             }
         }
 
+        NetworkRequestInfo(NetworkRequest r) {
+            this(r, null);
+        }
+
         private void enforceRequestCountLimit() {
             synchronized (mUidToNetworkRequestCount) {
                 int networkRequests = mUidToNetworkRequestCount.get(mUid, 0) + 1;
@@ -6174,7 +6190,10 @@
     private void callCallbackForRequest(NetworkRequestInfo nri,
             NetworkAgentInfo networkAgent, int notificationType, int arg1) {
         if (nri.messenger == null) {
-            return;  // Default request has no msgr
+            // Default request has no msgr. Also prevents callbacks from being invoked for
+            // NetworkRequestInfos registered with ConnectivityDiagnostics requests. Those callbacks
+            // are Type.LISTEN, but should not have NetworkCallbacks invoked.
+            return;
         }
         Bundle bundle = new Bundle();
         // TODO: check if defensive copies of data is needed.
@@ -7330,19 +7349,161 @@
         }
     }
 
+    /**
+     * Handler used for managing all Connectivity Diagnostics related functions.
+     *
+     * @see android.net.ConnectivityDiagnosticsManager
+     *
+     * TODO(b/147816404): Explore moving ConnectivityDiagnosticsHandler to a separate file
+     */
+    @VisibleForTesting
+    class ConnectivityDiagnosticsHandler extends Handler {
+        /**
+         * Used to handle ConnectivityDiagnosticsCallback registration events from {@link
+         * android.net.ConnectivityDiagnosticsManager}.
+         * obj = ConnectivityDiagnosticsCallbackInfo with IConnectivityDiagnosticsCallback and
+         * NetworkRequestInfo to be registered
+         */
+        private static final int EVENT_REGISTER_CONNECTIVITY_DIAGNOSTICS_CALLBACK = 1;
+
+        /**
+         * Used to handle ConnectivityDiagnosticsCallback unregister events from {@link
+         * android.net.ConnectivityDiagnosticsManager}.
+         * obj = the IConnectivityDiagnosticsCallback to be unregistered
+         * arg1 = the uid of the caller
+         */
+        private static final int EVENT_UNREGISTER_CONNECTIVITY_DIAGNOSTICS_CALLBACK = 2;
+
+        private ConnectivityDiagnosticsHandler(Looper looper) {
+            super(looper);
+        }
+
+        @Override
+        public void handleMessage(Message msg) {
+            switch (msg.what) {
+                case EVENT_REGISTER_CONNECTIVITY_DIAGNOSTICS_CALLBACK: {
+                    handleRegisterConnectivityDiagnosticsCallback(
+                            (ConnectivityDiagnosticsCallbackInfo) msg.obj);
+                    break;
+                }
+                case EVENT_UNREGISTER_CONNECTIVITY_DIAGNOSTICS_CALLBACK: {
+                    handleUnregisterConnectivityDiagnosticsCallback(
+                            (IConnectivityDiagnosticsCallback) msg.obj, msg.arg1);
+                    break;
+                }
+            }
+        }
+    }
+
+    /** Class used for cleaning up IConnectivityDiagnosticsCallback instances after their death. */
+    @VisibleForTesting
+    class ConnectivityDiagnosticsCallbackInfo implements Binder.DeathRecipient {
+        @NonNull private final IConnectivityDiagnosticsCallback mCb;
+        @NonNull private final NetworkRequestInfo mRequestInfo;
+
+        @VisibleForTesting
+        ConnectivityDiagnosticsCallbackInfo(
+                @NonNull IConnectivityDiagnosticsCallback cb, @NonNull NetworkRequestInfo nri) {
+            mCb = cb;
+            mRequestInfo = nri;
+        }
+
+        @Override
+        public void binderDied() {
+            log("ConnectivityDiagnosticsCallback IBinder died.");
+            unregisterConnectivityDiagnosticsCallback(mCb);
+        }
+    }
+
+    /**
+     * Records {@code cbInfo} as the registration for its callback and starts watching the
+     * callback's binder for death. A callback that is already present keeps its existing entry.
+     */
+    private void handleRegisterConnectivityDiagnosticsCallback(
+            @NonNull ConnectivityDiagnosticsCallbackInfo cbInfo) {
+        ensureRunningOnConnectivityServiceThread();
+
+        final IConnectivityDiagnosticsCallback callback = cbInfo.mCb;
+        if (mConnectivityDiagnosticsCallbacks.putIfAbsent(callback, cbInfo) != null) {
+            // Same callback registered more than once: the earlier entry wins and this one is
+            // dropped silently.
+            if (VDBG) log("Diagnostics callback is already registered");
+
+            // The NetworkRequestInfo constructor already charged this UID through
+            // enforceRequestCountLimit(); hand that slot back since the request is discarded.
+            decrementNetworkRequestPerUidCount(cbInfo.mRequestInfo);
+            return;
+        }
+
+        try {
+            callback.asBinder().linkToDeath(cbInfo, 0);
+        } catch (RemoteException linkFailure) {
+            // linkToDeath() failed, so run the death cleanup immediately.
+            cbInfo.binderDied();
+        }
+    }
+
+    /** Unregisters {@code cb} on behalf of {@code uid} and frees its per-UID request slot. */
+    private void handleUnregisterConnectivityDiagnosticsCallback(
+            @NonNull IConnectivityDiagnosticsCallback cb, int uid) {
+        ensureRunningOnConnectivityServiceThread();
+        final ConnectivityDiagnosticsCallbackInfo info = mConnectivityDiagnosticsCallbacks.get(cb);
+        if (info == null) {
+            if (VDBG) log("Removing diagnostics callback that is not currently registered");
+            return;
+        }
+        // binderDied() unregisters outside an incoming binder call, so uid is this process's own.
+        if (uid != info.mRequestInfo.mUid && uid != android.os.Process.myUid()) {
+            if (VDBG) loge("Different uid than registrant attempting to unregister cb");
+            return;
+        }
+        // Give back the slot enforceRequestCountLimit() charged when the request info was built.
+        decrementNetworkRequestPerUidCount(info.mRequestInfo);
+        cb.asBinder().unlinkToDeath(mConnectivityDiagnosticsCallbacks.remove(cb), 0);
+    }
+
     @Override
     public void registerConnectivityDiagnosticsCallback(
             @NonNull IConnectivityDiagnosticsCallback callback, @NonNull NetworkRequest request) {
-        // TODO(b/146444622): implement register IConnectivityDiagnosticsCallback functionality
-        throw new UnsupportedOperationException(
-                "registerConnectivityDiagnosticsCallback not yet implemented");
+        if (request.legacyType != TYPE_NONE) {
+            throw new IllegalArgumentException("ConnectivityManager.TYPE_* are deprecated."
+                    + " Please use NetworkCapabilities instead.");
+        }
+
+        // This NetworkCapabilities is only used for matching to Networks. Clear out its owner uid
+        // and administrator uids to be safe.
+        final NetworkCapabilities nc = new NetworkCapabilities(request.networkCapabilities);
+        restrictRequestUidsForCaller(nc);
+
+        final NetworkRequest requestWithId =
+                new NetworkRequest(
+                        nc, TYPE_NONE, nextNetworkRequestId(), NetworkRequest.Type.LISTEN);
+
+        // NetworkRequestInfos created here count towards MAX_NETWORK_REQUESTS_PER_UID limit.
+        //
+        // nri is not bound to the death of callback. Instead, callback.asBinder().linkToDeath() is
+        // called in handleRegisterConnectivityDiagnosticsCallback(). nri will be cleaned up as
+        // part of the callback's binder death.
+        final NetworkRequestInfo nri = new NetworkRequestInfo(requestWithId);
+        final ConnectivityDiagnosticsCallbackInfo cbInfo =
+                new ConnectivityDiagnosticsCallbackInfo(callback, nri);
+
+        mConnectivityDiagnosticsHandler.sendMessage(
+                mConnectivityDiagnosticsHandler.obtainMessage(
+                        ConnectivityDiagnosticsHandler
+                                .EVENT_REGISTER_CONNECTIVITY_DIAGNOSTICS_CALLBACK,
+                        cbInfo));
     }
 
     @Override
     public void unregisterConnectivityDiagnosticsCallback(
             @NonNull IConnectivityDiagnosticsCallback callback) {
-        // TODO(b/146444622): implement register IConnectivityDiagnosticsCallback functionality
-        throw new UnsupportedOperationException(
-                "unregisterConnectivityDiagnosticsCallback not yet implemented");
+        mConnectivityDiagnosticsHandler.sendMessage(
+                mConnectivityDiagnosticsHandler.obtainMessage(
+                        ConnectivityDiagnosticsHandler
+                                .EVENT_UNREGISTER_CONNECTIVITY_DIAGNOSTICS_CALLBACK,
+                        Binder.getCallingUid(),
+                        0,
+                        callback));
     }
 }
diff --git a/tests/net/java/android/net/ConnectivityDiagnosticsManagerTest.java b/tests/net/java/android/net/ConnectivityDiagnosticsManagerTest.java
index 7ab4b56..2d5df4f 100644
--- a/tests/net/java/android/net/ConnectivityDiagnosticsManagerTest.java
+++ b/tests/net/java/android/net/ConnectivityDiagnosticsManagerTest.java
@@ -27,12 +27,18 @@
 import static org.junit.Assert.assertFalse;
 import static org.junit.Assert.assertNotEquals;
 import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
+import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.ArgumentMatchers.eq;
 import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.verifyNoMoreInteractions;
 
+import android.content.Context;
 import android.os.PersistableBundle;
 
+import org.junit.After;
 import org.junit.Before;
 import org.junit.Test;
 import org.junit.runner.RunWith;
@@ -52,15 +58,27 @@
 
     private static final Executor INLINE_EXECUTOR = x -> x.run();
 
+    @Mock private Context mContext;
+    @Mock private IConnectivityManager mService;
     @Mock private ConnectivityDiagnosticsCallback mCb;
 
     private ConnectivityDiagnosticsBinder mBinder;
+    private ConnectivityDiagnosticsManager mManager;
 
     @Before
     public void setUp() {
+        mContext = mock(Context.class);
+        mService = mock(IConnectivityManager.class);
         mCb = mock(ConnectivityDiagnosticsCallback.class);
 
         mBinder = new ConnectivityDiagnosticsBinder(mCb, INLINE_EXECUTOR);
+        mManager = new ConnectivityDiagnosticsManager(mContext, mService);
+    }
+
+    @After
+    public void tearDown() {
+        // clear ConnectivityDiagnosticsManager callbacks map
+        ConnectivityDiagnosticsManager.sCallbacks.clear();
     }
 
     private ConnectivityReport createSampleConnectivityReport() {
@@ -245,4 +263,53 @@
         // latch without waiting.
         verify(mCb).onNetworkConnectivityReported(eq(n), eq(connectivity));
     }
+
+    @Test
+    public void testRegisterConnectivityDiagnosticsCallback() throws Exception {
+        final NetworkRequest request = new NetworkRequest.Builder().build();
+
+        mManager.registerConnectivityDiagnosticsCallback(request, INLINE_EXECUTOR, mCb);
+
+        verify(mService).registerConnectivityDiagnosticsCallback(
+                any(ConnectivityDiagnosticsBinder.class), eq(request));
+        assertTrue(ConnectivityDiagnosticsManager.sCallbacks.containsKey(mCb));
+    }
+
+    @Test
+    public void testRegisterDuplicateConnectivityDiagnosticsCallback() throws Exception {
+        final NetworkRequest request = new NetworkRequest.Builder().build();
+
+        mManager.registerConnectivityDiagnosticsCallback(request, INLINE_EXECUTOR, mCb);
+
+        try {
+            mManager.registerConnectivityDiagnosticsCallback(request, INLINE_EXECUTOR, mCb);
+            fail("Duplicate callback registration should fail");
+        } catch (IllegalArgumentException expected) {
+        }
+    }
+
+    @Test
+    public void testUnregisterConnectivityDiagnosticsCallback() throws Exception {
+        final NetworkRequest request = new NetworkRequest.Builder().build();
+        mManager.registerConnectivityDiagnosticsCallback(request, INLINE_EXECUTOR, mCb);
+
+        mManager.unregisterConnectivityDiagnosticsCallback(mCb);
+
+        verify(mService).unregisterConnectivityDiagnosticsCallback(
+                any(ConnectivityDiagnosticsBinder.class));
+        assertFalse(ConnectivityDiagnosticsManager.sCallbacks.containsKey(mCb));
+
+        // verify that re-registering is successful
+        mManager.registerConnectivityDiagnosticsCallback(request, INLINE_EXECUTOR, mCb);
+        verify(mService, times(2)).registerConnectivityDiagnosticsCallback(
+                any(ConnectivityDiagnosticsBinder.class), eq(request));
+        assertTrue(ConnectivityDiagnosticsManager.sCallbacks.containsKey(mCb));
+    }
+
+    @Test
+    public void testUnregisterUnknownConnectivityDiagnosticsCallback() throws Exception {
+        mManager.unregisterConnectivityDiagnosticsCallback(mCb);
+
+        verifyNoMoreInteractions(mService);
+    }
 }
diff --git a/tests/net/java/com/android/server/ConnectivityServiceTest.java b/tests/net/java/com/android/server/ConnectivityServiceTest.java
index e80b7c9..50f1bbe 100644
--- a/tests/net/java/com/android/server/ConnectivityServiceTest.java
+++ b/tests/net/java/com/android/server/ConnectivityServiceTest.java
@@ -139,6 +139,7 @@
 import android.net.ConnectivityManager.PacketKeepaliveCallback;
 import android.net.ConnectivityManager.TooManyRequestsException;
 import android.net.ConnectivityThread;
+import android.net.IConnectivityDiagnosticsCallback;
 import android.net.IDnsResolver;
 import android.net.IIpConnectivityMetrics;
 import android.net.INetd;
@@ -180,6 +181,7 @@
 import android.os.ConditionVariable;
 import android.os.Handler;
 import android.os.HandlerThread;
+import android.os.IBinder;
 import android.os.INetworkManagementService;
 import android.os.Looper;
 import android.os.Parcel;
@@ -210,6 +212,7 @@
 import com.android.internal.util.WakeupMessage;
 import com.android.internal.util.test.BroadcastInterceptingContext;
 import com.android.internal.util.test.FakeSettingsProvider;
+import com.android.server.ConnectivityService.ConnectivityDiagnosticsCallbackInfo;
 import com.android.server.connectivity.ConnectivityConstants;
 import com.android.server.connectivity.DefaultNetworkMetrics;
 import com.android.server.connectivity.IpConnectivityMetrics;
@@ -322,6 +325,8 @@
     @Mock UserManager mUserManager;
     @Mock NotificationManager mNotificationManager;
     @Mock AlarmManager mAlarmManager;
+    @Mock IConnectivityDiagnosticsCallback mConnectivityDiagnosticsCallback;
+    @Mock IBinder mIBinder;
 
     private ArgumentCaptor<ResolverParamsParcel> mResolverParamsParcelCaptor =
             ArgumentCaptor.forClass(ResolverParamsParcel.class);
@@ -6355,4 +6360,70 @@
                 UserHandle.getAppId(uid));
         return packageInfo;
     }
+
+    @Test
+    public void testRegisterConnectivityDiagnosticsCallbackInvalidRequest() throws Exception {
+        final NetworkRequest request =
+                new NetworkRequest(
+                        new NetworkCapabilities(), TYPE_ETHERNET, 0, NetworkRequest.Type.NONE);
+        try {
+            mService.registerConnectivityDiagnosticsCallback(
+                    mConnectivityDiagnosticsCallback, request);
+            fail("registerConnectivityDiagnosticsCallback should throw on invalid NetworkRequest");
+        } catch (IllegalArgumentException expected) {
+        }
+    }
+
+    @Test
+    public void testRegisterUnregisterConnectivityDiagnosticsCallback() throws Exception {
+        final NetworkRequest wifiRequest =
+                new NetworkRequest.Builder().addTransportType(TRANSPORT_WIFI).build();
+
+        when(mConnectivityDiagnosticsCallback.asBinder()).thenReturn(mIBinder);
+
+        mService.registerConnectivityDiagnosticsCallback(
+                mConnectivityDiagnosticsCallback, wifiRequest);
+
+        verify(mIBinder, timeout(TIMEOUT_MS))
+                .linkToDeath(any(ConnectivityDiagnosticsCallbackInfo.class), anyInt());
+        assertTrue(
+                mService.mConnectivityDiagnosticsCallbacks.containsKey(
+                        mConnectivityDiagnosticsCallback));
+
+        mService.unregisterConnectivityDiagnosticsCallback(mConnectivityDiagnosticsCallback);
+        verify(mIBinder, timeout(TIMEOUT_MS))
+                .unlinkToDeath(any(ConnectivityDiagnosticsCallbackInfo.class), anyInt());
+        assertFalse(
+                mService.mConnectivityDiagnosticsCallbacks.containsKey(
+                        mConnectivityDiagnosticsCallback));
+        verify(mConnectivityDiagnosticsCallback, atLeastOnce()).asBinder();
+    }
+
+    @Test
+    public void testRegisterDuplicateConnectivityDiagnosticsCallback() throws Exception {
+        final NetworkRequest wifiRequest =
+                new NetworkRequest.Builder().addTransportType(TRANSPORT_WIFI).build();
+        when(mConnectivityDiagnosticsCallback.asBinder()).thenReturn(mIBinder);
+
+        mService.registerConnectivityDiagnosticsCallback(
+                mConnectivityDiagnosticsCallback, wifiRequest);
+
+        verify(mIBinder, timeout(TIMEOUT_MS))
+                .linkToDeath(any(ConnectivityDiagnosticsCallbackInfo.class), anyInt());
+        verify(mConnectivityDiagnosticsCallback).asBinder();
+        assertTrue(
+                mService.mConnectivityDiagnosticsCallbacks.containsKey(
+                        mConnectivityDiagnosticsCallback));
+
+        // Register the same callback again
+        mService.registerConnectivityDiagnosticsCallback(
+                mConnectivityDiagnosticsCallback, wifiRequest);
+
+        // Block until all other events are done processing.
+        HandlerUtilsKt.waitForIdle(mCsHandlerThread, TIMEOUT_MS);
+
+        assertTrue(
+                mService.mConnectivityDiagnosticsCallbacks.containsKey(
+                        mConnectivityDiagnosticsCallback));
+    }
 }