Merge "Use IBinder as key for ConnectivityDiagnostics storage in CS."
diff --git a/services/core/java/com/android/server/ConnectivityService.java b/services/core/java/com/android/server/ConnectivityService.java
index ec64454..4d03f66 100644
--- a/services/core/java/com/android/server/ConnectivityService.java
+++ b/services/core/java/com/android/server/ConnectivityService.java
@@ -653,8 +653,8 @@
     final MultipathPolicyTracker mMultipathPolicyTracker;
 
     @VisibleForTesting
-    final Map<IConnectivityDiagnosticsCallback, ConnectivityDiagnosticsCallbackInfo>
-            mConnectivityDiagnosticsCallbacks = new HashMap<>();
+    final Map<IBinder, ConnectivityDiagnosticsCallbackInfo> mConnectivityDiagnosticsCallbacks =
+            new HashMap<>();
 
     /**
      * Implements support for the legacy "one network per network type" model.
@@ -7826,11 +7826,12 @@
         ensureRunningOnConnectivityServiceThread();
 
         final IConnectivityDiagnosticsCallback cb = cbInfo.mCb;
+        final IBinder iCb = cb.asBinder();
         final NetworkRequestInfo nri = cbInfo.mRequestInfo;
 
         // This means that the client registered the same callback multiple times. Do
         // not override the previous entry, and exit silently.
-        if (mConnectivityDiagnosticsCallbacks.containsKey(cb)) {
+        if (mConnectivityDiagnosticsCallbacks.containsKey(iCb)) {
             if (VDBG) log("Diagnostics callback is already registered");
 
             // Decrement the reference count for this NetworkRequestInfo. The reference count is
@@ -7840,10 +7841,10 @@
             return;
         }
 
-        mConnectivityDiagnosticsCallbacks.put(cb, cbInfo);
+        mConnectivityDiagnosticsCallbacks.put(iCb, cbInfo);
 
         try {
-            cb.asBinder().linkToDeath(cbInfo, 0);
+            iCb.linkToDeath(cbInfo, 0);
         } catch (RemoteException e) {
             cbInfo.binderDied();
             return;
@@ -7880,13 +7881,14 @@
     private void handleUnregisterConnectivityDiagnosticsCallback(
             @NonNull IConnectivityDiagnosticsCallback cb, int uid) {
         ensureRunningOnConnectivityServiceThread();
+        final IBinder iCb = cb.asBinder();
 
-        if (!mConnectivityDiagnosticsCallbacks.containsKey(cb)) {
+        if (!mConnectivityDiagnosticsCallbacks.containsKey(iCb)) {
             if (VDBG) log("Removing diagnostics callback that is not currently registered");
             return;
         }
 
-        final NetworkRequestInfo nri = mConnectivityDiagnosticsCallbacks.get(cb).mRequestInfo;
+        final NetworkRequestInfo nri = mConnectivityDiagnosticsCallbacks.get(iCb).mRequestInfo;
 
         if (uid != nri.mUid) {
             if (VDBG) loge("Different uid than registrant attempting to unregister cb");
@@ -7898,7 +7900,9 @@
         // enforceRequestCountLimit().
         decrementNetworkRequestPerUidCount(nri);
 
-        cb.asBinder().unlinkToDeath(mConnectivityDiagnosticsCallbacks.remove(cb), 0);
+        final ConnectivityDiagnosticsCallbackInfo cbInfo =
+                mConnectivityDiagnosticsCallbacks.remove(iCb);
+        iCb.unlinkToDeath(cbInfo, 0);
     }
 
     private void handleNetworkTestedWithExtras(
@@ -7973,14 +7977,14 @@
     private List<IConnectivityDiagnosticsCallback> getMatchingPermissionedCallbacks(
             @NonNull NetworkAgentInfo nai) {
         final List<IConnectivityDiagnosticsCallback> results = new ArrayList<>();
-        for (Entry<IConnectivityDiagnosticsCallback, ConnectivityDiagnosticsCallbackInfo> entry :
+        for (Entry<IBinder, ConnectivityDiagnosticsCallbackInfo> entry :
                 mConnectivityDiagnosticsCallbacks.entrySet()) {
             final ConnectivityDiagnosticsCallbackInfo cbInfo = entry.getValue();
             final NetworkRequestInfo nri = cbInfo.mRequestInfo;
             if (nai.satisfies(nri.request)) {
                 if (checkConnectivityDiagnosticsPermissions(
                         nri.mPid, nri.mUid, nai, cbInfo.mCallingPackageName)) {
-                    results.add(entry.getKey());
+                    results.add(entry.getValue().mCb);
                 }
             }
         }
diff --git a/tests/net/java/com/android/server/ConnectivityServiceTest.java b/tests/net/java/com/android/server/ConnectivityServiceTest.java
index 1d7d3c0..7b9d2bd 100644
--- a/tests/net/java/com/android/server/ConnectivityServiceTest.java
+++ b/tests/net/java/com/android/server/ConnectivityServiceTest.java
@@ -6744,16 +6744,12 @@
 
         verify(mIBinder).linkToDeath(any(ConnectivityDiagnosticsCallbackInfo.class), anyInt());
         verify(mConnectivityDiagnosticsCallback).asBinder();
-        assertTrue(
-                mService.mConnectivityDiagnosticsCallbacks.containsKey(
-                        mConnectivityDiagnosticsCallback));
+        assertTrue(mService.mConnectivityDiagnosticsCallbacks.containsKey(mIBinder));
 
         mService.unregisterConnectivityDiagnosticsCallback(mConnectivityDiagnosticsCallback);
         verify(mIBinder, timeout(TIMEOUT_MS))
                 .unlinkToDeath(any(ConnectivityDiagnosticsCallbackInfo.class), anyInt());
-        assertFalse(
-                mService.mConnectivityDiagnosticsCallbacks.containsKey(
-                        mConnectivityDiagnosticsCallback));
+        assertFalse(mService.mConnectivityDiagnosticsCallbacks.containsKey(mIBinder));
         verify(mConnectivityDiagnosticsCallback, atLeastOnce()).asBinder();
     }
 
@@ -6771,9 +6767,7 @@
 
         verify(mIBinder).linkToDeath(any(ConnectivityDiagnosticsCallbackInfo.class), anyInt());
         verify(mConnectivityDiagnosticsCallback).asBinder();
-        assertTrue(
-                mService.mConnectivityDiagnosticsCallbacks.containsKey(
-                        mConnectivityDiagnosticsCallback));
+        assertTrue(mService.mConnectivityDiagnosticsCallbacks.containsKey(mIBinder));
 
         // Register the same callback again
         mService.registerConnectivityDiagnosticsCallback(
@@ -6782,9 +6776,7 @@
         // Block until all other events are done processing.
         HandlerUtilsKt.waitForIdle(mCsHandlerThread, TIMEOUT_MS);
 
-        assertTrue(
-                mService.mConnectivityDiagnosticsCallbacks.containsKey(
-                        mConnectivityDiagnosticsCallback));
+        assertTrue(mService.mConnectivityDiagnosticsCallbacks.containsKey(mIBinder));
     }
 
     @Test