Store the SubscriptionId together with the Carrier UID

Always store the corresponding SubscriptionId together with the Carrier
uid so that when the carrier loses the Privilege, both the uid and
SubscriptionID will be used in the onCarrierPrivilegesLost callback

Bug: 324357121
Test: atest ConnectivityCoverageTests:android.net.connectivity.com.android.server.connectivity.CarrierPrivilegeAuthenticatorTest
      atest ConnectivityCoverageTests:android.net.connectivity.com.android.server.ConnectivityServiceTest
Change-Id: I28e51c583261a67d4441c6f825ade6781b862ee4
diff --git a/service/src/com/android/server/ConnectivityService.java b/service/src/com/android/server/ConnectivityService.java
index e6287bc..3d646fd 100755
--- a/service/src/com/android/server/ConnectivityService.java
+++ b/service/src/com/android/server/ConnectivityService.java
@@ -114,7 +114,6 @@
 import static com.android.net.module.util.PermissionUtils.enforceNetworkStackPermissionOr;
 import static com.android.net.module.util.PermissionUtils.hasAnyPermissionOf;
 import static com.android.server.ConnectivityStatsLog.CONNECTIVITY_STATE_SAMPLE;
-import static com.android.server.connectivity.CarrierPrivilegeAuthenticator.CarrierPrivilegesLostListener;
 import static com.android.server.connectivity.ConnectivityFlags.REQUEST_RESTRICTED_WIFI;
 
 import android.Manifest;
@@ -257,6 +256,7 @@
 import android.stats.connectivity.ValidatedState;
 import android.sysprop.NetworkProperties;
 import android.system.ErrnoException;
+import android.telephony.SubscriptionManager;
 import android.telephony.TelephonyManager;
 import android.text.TextUtils;
 import android.util.ArrayMap;
@@ -377,6 +377,7 @@
 import java.util.TreeSet;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicInteger;
+import java.util.function.BiConsumer;
 import java.util.function.Consumer;
 
 /**
@@ -1287,18 +1288,14 @@
     }
     private final LegacyTypeTracker mLegacyTypeTracker = new LegacyTypeTracker(this);
 
-    private final CarrierPrivilegesLostListenerImpl mCarrierPrivilegesLostListenerImpl =
-            new CarrierPrivilegesLostListenerImpl();
-
-    private class CarrierPrivilegesLostListenerImpl implements CarrierPrivilegesLostListener {
-        @Override
-        public void onCarrierPrivilegesLost(int uid) {
-            if (mRequestRestrictedWifiEnabled) {
-                mHandler.sendMessage(mHandler.obtainMessage(
-                        EVENT_UID_CARRIER_PRIVILEGES_LOST, uid, 0 /* arg2 */));
-            }
+    @VisibleForTesting
+    void onCarrierPrivilegesLost(Integer uid, Integer subId) {
+        if (mRequestRestrictedWifiEnabled) {
+            mHandler.sendMessage(mHandler.obtainMessage(
+                    EVENT_UID_CARRIER_PRIVILEGES_LOST, uid, subId));
         }
     }
+
     final LocalPriorityDump mPriorityDumper = new LocalPriorityDump();
     /**
      * Helper class which parses out priority arguments and dumps sections according to their
@@ -1357,11 +1354,6 @@
         }
     }
 
-    @VisibleForTesting
-    CarrierPrivilegesLostListener getCarrierPrivilegesLostListener() {
-        return mCarrierPrivilegesLostListenerImpl;
-    }
-
     /**
      * Dependencies of ConnectivityService, for injection in tests.
      */
@@ -1525,7 +1517,7 @@
                 @NonNull final Context context,
                 @NonNull final TelephonyManager tm,
                 boolean requestRestrictedWifiEnabled,
-                @NonNull CarrierPrivilegesLostListener listener) {
+                @NonNull BiConsumer<Integer, Integer> listener) {
             if (isAtLeastT()) {
                 return new CarrierPrivilegeAuthenticator(
                         context, tm, requestRestrictedWifiEnabled, listener);
@@ -1813,7 +1805,7 @@
                 && mDeps.isFeatureEnabled(context, REQUEST_RESTRICTED_WIFI);
         mCarrierPrivilegeAuthenticator = mDeps.makeCarrierPrivilegeAuthenticator(
                 mContext, mTelephonyManager, mRequestRestrictedWifiEnabled,
-                mCarrierPrivilegesLostListenerImpl);
+                this::onCarrierPrivilegesLost);
 
         if (mDeps.isAtLeastU()
                 && mDeps
@@ -5401,6 +5393,13 @@
         return false;
     }
 
+    private int getSubscriptionIdFromNetworkCaps(@NonNull final NetworkCapabilities caps) {
+        if (mCarrierPrivilegeAuthenticator != null) {
+            return mCarrierPrivilegeAuthenticator.getSubIdFromNetworkCapabilities(caps);
+        }
+        return SubscriptionManager.INVALID_SUBSCRIPTION_ID;
+    }
+
     private void handleRegisterNetworkRequestWithIntent(@NonNull final Message msg) {
         final NetworkRequestInfo nri = (NetworkRequestInfo) (msg.obj);
         // handleRegisterNetworkRequestWithIntent() doesn't apply to multilayer requests.
@@ -6492,7 +6491,7 @@
                     handleFrozenUids(args.mUids, args.mFrozenStates);
                     break;
                 case EVENT_UID_CARRIER_PRIVILEGES_LOST:
-                    handleUidCarrierPrivilegesLost(msg.arg1);
+                    handleUidCarrierPrivilegesLost(msg.arg1, msg.arg2);
                     break;
             }
         }
@@ -9155,7 +9154,7 @@
         }
     }
 
-    private void handleUidCarrierPrivilegesLost(int uid) {
+    private void handleUidCarrierPrivilegesLost(int uid, int subId) {
         ensureRunningOnConnectivityServiceThread();
         // A NetworkRequest needs to be revoked when all the conditions are met
         //   1. It requests restricted network
@@ -9166,6 +9165,7 @@
             if ((nr.isRequest() || nr.isListen())
                     && !nr.hasCapability(NET_CAPABILITY_NOT_RESTRICTED)
                     && nr.getRequestorUid() == uid
+                    && getSubscriptionIdFromNetworkCaps(nr.networkCapabilities) == subId
                     && !hasConnectivityRestrictedNetworksPermission(uid, true)) {
                 declareNetworkRequestUnfulfillable(nr);
             }
@@ -9174,7 +9174,8 @@
         // A NetworkAgent's allowedUids may need to be updated if the app has lost
         // carrier config
         for (final NetworkAgentInfo nai : mNetworkAgentInfos) {
-            if (nai.networkCapabilities.getAllowedUidsNoCopy().contains(uid)) {
+            if (nai.networkCapabilities.getAllowedUidsNoCopy().contains(uid)
+                    && getSubscriptionIdFromNetworkCaps(nai.networkCapabilities) == subId) {
                 final NetworkCapabilities nc = new NetworkCapabilities(nai.networkCapabilities);
                 NetworkAgentInfo.restrictCapabilitiesFromNetworkAgent(
                         nc,
diff --git a/service/src/com/android/server/connectivity/CarrierPrivilegeAuthenticator.java b/service/src/com/android/server/connectivity/CarrierPrivilegeAuthenticator.java
index 533278e..04d0fc1 100644
--- a/service/src/com/android/server/connectivity/CarrierPrivilegeAuthenticator.java
+++ b/service/src/com/android/server/connectivity/CarrierPrivilegeAuthenticator.java
@@ -40,12 +40,13 @@
 import android.telephony.SubscriptionManager;
 import android.telephony.TelephonyManager;
 import android.util.Log;
-import android.util.SparseIntArray;
+import android.util.SparseArray;
 
 import com.android.internal.annotations.GuardedBy;
 import com.android.internal.annotations.VisibleForTesting;
 import com.android.internal.util.IndentingPrintWriter;
 import com.android.modules.utils.HandlerExecutor;
+import com.android.modules.utils.build.SdkLevel;
 import com.android.net.module.util.DeviceConfigUtils;
 import com.android.networkstack.apishim.TelephonyManagerShimImpl;
 import com.android.networkstack.apishim.common.TelephonyManagerShim;
@@ -55,6 +56,7 @@
 import java.util.ArrayList;
 import java.util.List;
 import java.util.concurrent.Executor;
+import java.util.function.BiConsumer;
 
 /**
  * Tracks the uid of the carrier privileged app that provides the carrier config.
@@ -71,7 +73,8 @@
     private final TelephonyManagerShim mTelephonyManagerShim;
     private final TelephonyManager mTelephonyManager;
     @GuardedBy("mLock")
-    private final SparseIntArray mCarrierServiceUid = new SparseIntArray(2 /* initialCapacity */);
+    private final SparseArray<CarrierServiceUidWithSubId> mCarrierServiceUidWithSubId =
+            new SparseArray<>(2 /* initialCapacity */);
     @GuardedBy("mLock")
     private int mModemCount = 0;
     private final Object mLock = new Object();
@@ -81,14 +84,14 @@
     private final boolean mUseCallbacksForServiceChanged;
     private final boolean mRequestRestrictedWifiEnabled;
     @NonNull
-    private final CarrierPrivilegesLostListener mListener;
+    private final BiConsumer<Integer, Integer> mListener;
 
     public CarrierPrivilegeAuthenticator(@NonNull final Context c,
             @NonNull final Dependencies deps,
             @NonNull final TelephonyManager t,
             @NonNull final TelephonyManagerShim telephonyManagerShim,
             final boolean requestRestrictedWifiEnabled,
-            @NonNull CarrierPrivilegesLostListener listener) {
+            @NonNull BiConsumer<Integer, Integer> listener) {
         mContext = c;
         mTelephonyManager = t;
         mTelephonyManagerShim = telephonyManagerShim;
@@ -121,7 +124,7 @@
 
     public CarrierPrivilegeAuthenticator(@NonNull final Context c,
             @NonNull final TelephonyManager t, final boolean requestRestrictedWifiEnabled,
-            @NonNull CarrierPrivilegesLostListener listener) {
+            @NonNull BiConsumer<Integer, Integer> listener) {
         this(c, new Dependencies(), t, TelephonyManagerShimImpl.newInstance(t),
                 requestRestrictedWifiEnabled, listener);
     }
@@ -142,18 +145,6 @@
         }
     }
 
-    /**
-     * Listener interface to get a notification when the carrier App lost its privileges.
-     */
-    public interface CarrierPrivilegesLostListener {
-        /**
-         * Called when the carrier App lost its privileges.
-         *
-         * @param uid  The uid of the carrier app which has lost its privileges.
-         */
-        void onCarrierPrivilegesLost(int uid);
-    }
-
     private void simConfigChanged() {
         synchronized (mLock) {
             unregisterCarrierPrivilegesListeners();
@@ -163,6 +154,29 @@
         }
     }
 
+    private static class CarrierServiceUidWithSubId {
+        final int mUid;
+        final int mSubId;
+
+        CarrierServiceUidWithSubId(int uid, int subId) {
+            mUid = uid;
+            mSubId = subId;
+        }
+
+        @Override
+        public boolean equals(Object obj) {
+            if (!(obj instanceof CarrierServiceUidWithSubId)) {
+                return false;
+            }
+            CarrierServiceUidWithSubId compare = (CarrierServiceUidWithSubId) obj;
+            return (mUid == compare.mUid && mSubId == compare.mSubId);
+        }
+
+        @Override
+        public int hashCode() {
+            return mUid * 31 + mSubId;
+        }
+    }
     private class PrivilegeListener implements CarrierPrivilegesListenerShim {
         public final int mLogicalSlot;
 
@@ -192,10 +206,17 @@
                 return;
             }
             synchronized (mLock) {
-                int oldUid = mCarrierServiceUid.get(mLogicalSlot);
-                mCarrierServiceUid.put(mLogicalSlot, carrierServiceUid);
-                if (oldUid != 0 && oldUid != carrierServiceUid) {
-                    mListener.onCarrierPrivilegesLost(oldUid);
+                CarrierServiceUidWithSubId oldPair =
+                        mCarrierServiceUidWithSubId.get(mLogicalSlot);
+                int subId = getSubId(mLogicalSlot);
+                mCarrierServiceUidWithSubId.put(
+                        mLogicalSlot,
+                        new CarrierServiceUidWithSubId(carrierServiceUid, subId));
+                if (oldPair != null
+                        && oldPair.mUid != Process.INVALID_UID
+                        && oldPair.mSubId != SubscriptionManager.INVALID_SUBSCRIPTION_ID
+                        && !oldPair.equals(mCarrierServiceUidWithSubId.get(mLogicalSlot))) {
+                    mListener.accept(oldPair.mUid, oldPair.mSubId);
                 }
             }
         }
@@ -218,10 +239,13 @@
     private void unregisterCarrierPrivilegesListeners() {
         for (PrivilegeListener carrierPrivilegesListener : mCarrierPrivilegesChangedListeners) {
             removeCarrierPrivilegesListener(carrierPrivilegesListener);
-            int oldUid = mCarrierServiceUid.get(carrierPrivilegesListener.mLogicalSlot);
-            mCarrierServiceUid.delete(carrierPrivilegesListener.mLogicalSlot);
-            if (oldUid != 0) {
-                mListener.onCarrierPrivilegesLost(oldUid);
+            CarrierServiceUidWithSubId oldPair =
+                    mCarrierServiceUidWithSubId.get(carrierPrivilegesListener.mLogicalSlot);
+            mCarrierServiceUidWithSubId.remove(carrierPrivilegesListener.mLogicalSlot);
+            if (oldPair != null
+                    && oldPair.mUid != Process.INVALID_UID
+                    && oldPair.mSubId != SubscriptionManager.INVALID_SUBSCRIPTION_ID) {
+                mListener.accept(oldPair.mUid, oldPair.mSubId);
             }
         }
         mCarrierPrivilegesChangedListeners.clear();
@@ -259,7 +283,23 @@
      */
     public boolean isCarrierServiceUidForNetworkCapabilities(int callingUid,
             @NonNull NetworkCapabilities networkCapabilities) {
-        if (callingUid == Process.INVALID_UID) return false;
+        if (callingUid == Process.INVALID_UID) {
+            return false;
+        }
+        int subId = getSubIdFromNetworkCapabilities(networkCapabilities);
+        if (SubscriptionManager.INVALID_SUBSCRIPTION_ID == subId) {
+            return false;
+        }
+        return callingUid == getCarrierServiceUidForSubId(subId);
+    }
+
+    /**
+     * Extract the SubscriptionId from the NetworkCapabilities.
+     *
+     * @param networkCapabilities the network capabilities which may contains the SubscriptionId.
+     * @return the SubscriptionId.
+     */
+    public int getSubIdFromNetworkCapabilities(@NonNull NetworkCapabilities networkCapabilities) {
         int subId;
         if (networkCapabilities.hasSingleTransportBesidesTest(TRANSPORT_CELLULAR)) {
             subId = getSubIdFromTelephonySpecifier(networkCapabilities.getNetworkSpecifier());
@@ -285,21 +325,42 @@
             Log.wtf(TAG, "NetworkCapabilities subIds are inconsistent between "
                     + "specifier/transportInfo and mSubIds : " + networkCapabilities);
         }
-        if (SubscriptionManager.INVALID_SUBSCRIPTION_ID == subId) return false;
-        return callingUid == getCarrierServiceUidForSubId(subId);
+        return subId;
+    }
+
+    @VisibleForTesting
+    protected int getSubId(int slotIndex) {
+        if (SdkLevel.isAtLeastU()) {
+            return SubscriptionManager.getSubscriptionId(slotIndex);
+        } else {
+            SubscriptionManager sm = mContext.getSystemService(SubscriptionManager.class);
+            int[] subIds = sm.getSubscriptionIds(slotIndex);
+            if (subIds != null && subIds.length > 0) {
+                return subIds[0];
+            }
+            return SubscriptionManager.INVALID_SUBSCRIPTION_ID;
+        }
     }
 
     @VisibleForTesting
     void updateCarrierServiceUid() {
         synchronized (mLock) {
-            SparseIntArray oldCarrierServiceUid = mCarrierServiceUid.clone();
-            mCarrierServiceUid.clear();
+            SparseArray<CarrierServiceUidWithSubId> copy = mCarrierServiceUidWithSubId.clone();
+            mCarrierServiceUidWithSubId.clear();
             for (int i = 0; i < mModemCount; i++) {
-                mCarrierServiceUid.put(i, getCarrierServicePackageUidForSlot(i));
+                int subId = getSubId(i);
+                mCarrierServiceUidWithSubId.put(
+                        i,
+                        new CarrierServiceUidWithSubId(
+                                getCarrierServicePackageUidForSlot(i), subId));
             }
-            for (int i = 0; i < oldCarrierServiceUid.size(); i++) {
-                if (mCarrierServiceUid.indexOfValue(oldCarrierServiceUid.valueAt(i)) < 0) {
-                    mListener.onCarrierPrivilegesLost(oldCarrierServiceUid.valueAt(i));
+            for (int i = 0; i < copy.size(); ++i) {
+                CarrierServiceUidWithSubId oldPair = copy.valueAt(i);
+                CarrierServiceUidWithSubId newPair = mCarrierServiceUidWithSubId.get(copy.keyAt(i));
+                if (oldPair.mUid != Process.INVALID_UID
+                        && oldPair.mSubId != SubscriptionManager.INVALID_SUBSCRIPTION_ID
+                        && !oldPair.equals(newPair)) {
+                    mListener.accept(oldPair.mUid, oldPair.mSubId);
                 }
             }
         }
@@ -307,18 +368,17 @@
 
     @VisibleForTesting
     int getCarrierServiceUidForSubId(int subId) {
-        final int slotId = getSlotIndex(subId);
         synchronized (mLock) {
-            return mCarrierServiceUid.get(slotId, Process.INVALID_UID);
+            for (int i = 0; i < mCarrierServiceUidWithSubId.size(); ++i) {
+                if (mCarrierServiceUidWithSubId.valueAt(i).mSubId == subId) {
+                    return mCarrierServiceUidWithSubId.valueAt(i).mUid;
+                }
+            }
+            return Process.INVALID_UID;
         }
     }
 
     @VisibleForTesting
-    protected int getSlotIndex(int subId) {
-        return SubscriptionManager.getSlotIndex(subId);
-    }
-
-    @VisibleForTesting
     int getUidForPackage(String pkgName) {
         if (pkgName == null) {
             return Process.INVALID_UID;
@@ -383,11 +443,12 @@
         pw.println("CarrierPrivilegeAuthenticator:");
         pw.println("mRequestRestrictedWifiEnabled = " + mRequestRestrictedWifiEnabled);
         synchronized (mLock) {
-            final int size = mCarrierServiceUid.size();
-            for (int i = 0; i < size; ++i) {
-                final int logicalSlot = mCarrierServiceUid.keyAt(i);
-                final int serviceUid = mCarrierServiceUid.valueAt(i);
-                pw.println("Logical slot = " + logicalSlot + " : uid = " + serviceUid);
+            for (int i = 0; i < mCarrierServiceUidWithSubId.size(); ++i) {
+                final int logicalSlot = mCarrierServiceUidWithSubId.keyAt(i);
+                final int serviceUid = mCarrierServiceUidWithSubId.valueAt(i).mUid;
+                final int subId = mCarrierServiceUidWithSubId.valueAt(i).mSubId;
+                pw.println("Logical slot = " + logicalSlot + " : uid = " + serviceUid
+                        + " : subId = " + subId);
             }
         }
     }
diff --git a/tests/integration/src/com/android/server/net/integrationtests/ConnectivityServiceIntegrationTest.kt b/tests/integration/src/com/android/server/net/integrationtests/ConnectivityServiceIntegrationTest.kt
index 9148770..361d68c 100644
--- a/tests/integration/src/com/android/server/net/integrationtests/ConnectivityServiceIntegrationTest.kt
+++ b/tests/integration/src/com/android/server/net/integrationtests/ConnectivityServiceIntegrationTest.kt
@@ -56,7 +56,6 @@
 import com.android.server.NetworkAgentWrapper
 import com.android.server.TestNetIdManager
 import com.android.server.connectivity.CarrierPrivilegeAuthenticator
-import com.android.server.connectivity.CarrierPrivilegeAuthenticator.CarrierPrivilegesLostListener
 import com.android.server.connectivity.ConnectivityResources
 import com.android.server.connectivity.MockableSystemProperties
 import com.android.server.connectivity.MultinetworkPolicyTracker
@@ -89,6 +88,7 @@
 import org.mockito.MockitoAnnotations
 import org.mockito.Spy
 import java.util.function.Consumer
+import java.util.function.BiConsumer
 
 const val SERVICE_BIND_TIMEOUT_MS = 5_000L
 const val TEST_TIMEOUT_MS = 10_000L
@@ -245,7 +245,7 @@
             context: Context,
             tm: TelephonyManager,
             requestRestrictedWifiEnabled: Boolean,
-            listener: CarrierPrivilegesLostListener
+            listener: BiConsumer<Int, Int>
         ): CarrierPrivilegeAuthenticator {
             return CarrierPrivilegeAuthenticator(context,
                 object : CarrierPrivilegeAuthenticator.Dependencies() {
diff --git a/tests/unit/java/com/android/server/ConnectivityServiceTest.java b/tests/unit/java/com/android/server/ConnectivityServiceTest.java
index 6623bbd..c534025 100755
--- a/tests/unit/java/com/android/server/ConnectivityServiceTest.java
+++ b/tests/unit/java/com/android/server/ConnectivityServiceTest.java
@@ -173,7 +173,6 @@
 import static com.android.server.ConnectivityServiceTestUtils.transportToLegacyType;
 import static com.android.server.NetworkAgentWrapper.CallbackType.OnQosCallbackRegister;
 import static com.android.server.NetworkAgentWrapper.CallbackType.OnQosCallbackUnregister;
-import static com.android.server.connectivity.CarrierPrivilegeAuthenticator.CarrierPrivilegesLostListener;
 import static com.android.testutils.Cleanup.testAndCleanup;
 import static com.android.testutils.ConcurrentUtils.await;
 import static com.android.testutils.ConcurrentUtils.durationOf;
@@ -488,6 +487,7 @@
 import java.util.concurrent.TimeoutException;
 import java.util.concurrent.atomic.AtomicBoolean;
 import java.util.concurrent.atomic.AtomicReference;
+import java.util.function.BiConsumer;
 import java.util.function.Consumer;
 import java.util.function.Predicate;
 import java.util.function.Supplier;
@@ -526,7 +526,7 @@
     // between a LOST callback that arrives immediately and a LOST callback that arrives after
     // the linger/nascent timeout. For this, our assertions should run fast enough to leave
     // less than (mService.mLingerDelayMs - TEST_CALLBACK_TIMEOUT_MS) between the time callbacks are
-    // supposedly fired, and the time we call expectCallback.
+    // supposedly fired, and the time we call expectCapChanged.
     private static final int TEST_CALLBACK_TIMEOUT_MS = 250;
     // Chosen to be less than TEST_CALLBACK_TIMEOUT_MS. This ensures that requests have time to
     // complete before callbacks are verified.
@@ -565,6 +565,7 @@
     private static final int TEST_PACKAGE_UID2 = 321;
     private static final int TEST_PACKAGE_UID3 = 456;
     private static final int NETWORK_ACTIVITY_NO_UID = -1;
+    private static final int TEST_SUBSCRIPTION_ID = 1;
 
     private static final int PACKET_WAKEUP_MARK_MASK = 0x80000000;
 
@@ -2059,7 +2060,7 @@
                 @NonNull final Context context,
                 @NonNull final TelephonyManager tm,
                 final boolean requestRestrictedWifiEnabled,
-                CarrierPrivilegesLostListener listener) {
+                BiConsumer<Integer, Integer> listener) {
             return mDeps.isAtLeastT() ? mCarrierPrivilegeAuthenticator : null;
         }
 
@@ -11486,7 +11487,7 @@
         doTestInterfaceClassActivityChanged(TRANSPORT_CELLULAR);
     }
 
-    private void doTestOnNetworkActive_NewNetworkConnects(int transportType, boolean expectCallback)
+    private void doTestOnNetworkActive_NewNetworkConnects(int transportType, boolean expectCapChanged)
             throws Exception {
         final ConditionVariable onNetworkActiveCv = new ConditionVariable();
         final ConnectivityManager.OnNetworkActiveListener listener = onNetworkActiveCv::open;
@@ -11498,7 +11499,7 @@
         testAndCleanup(() -> {
             mCm.addDefaultNetworkActiveListener(listener);
             agent.connect(true);
-            if (expectCallback) {
+            if (expectCapChanged) {
                 assertTrue(onNetworkActiveCv.block(TEST_CALLBACK_TIMEOUT_MS));
             } else {
                 assertFalse(onNetworkActiveCv.block(TEST_CALLBACK_TIMEOUT_MS));
@@ -11513,7 +11514,7 @@
 
     @Test
     public void testOnNetworkActive_NewCellConnects_CallbackCalled() throws Exception {
-        doTestOnNetworkActive_NewNetworkConnects(TRANSPORT_CELLULAR, true /* expectCallback */);
+        doTestOnNetworkActive_NewNetworkConnects(TRANSPORT_CELLULAR, true /* expectCapChanged */);
     }
 
     @Test
@@ -11522,8 +11523,8 @@
         // networks that tracker adds the idle timer to. And the tracker does not set the idle timer
         // for the ethernet network.
         // So onNetworkActive is not called when the ethernet becomes the default network
-        final boolean expectCallback = mDeps.isAtLeastV();
-        doTestOnNetworkActive_NewNetworkConnects(TRANSPORT_ETHERNET, expectCallback);
+        final boolean expectCapChanged = mDeps.isAtLeastV();
+        doTestOnNetworkActive_NewNetworkConnects(TRANSPORT_ETHERNET, expectCapChanged);
     }
 
     @Test
@@ -17375,7 +17376,7 @@
         return new NetworkRequest.Builder()
             .addTransportType(NetworkCapabilities.TRANSPORT_WIFI)
             .removeCapability(NetworkCapabilities.NET_CAPABILITY_NOT_RESTRICTED)
-            .setSubscriptionIds(Collections.singleton(Process.myUid()))
+            .setSubscriptionIds(Collections.singleton(TEST_SUBSCRIPTION_ID))
             .build();
     }
 
@@ -17422,32 +17423,46 @@
         final NetworkCallback networkCallback1 = new NetworkCallback();
         final NetworkCallback networkCallback2 = new NetworkCallback();
 
-        mCm.requestNetwork(getRestrictedRequestForWifiWithSubIds(), networkCallback1);
-        mCm.requestNetwork(getRestrictedRequestForWifiWithSubIds(), pendingIntent);
-        mCm.registerNetworkCallback(getRestrictedRequestForWifiWithSubIds(), networkCallback2);
+        mCm.requestNetwork(
+                getRestrictedRequestForWifiWithSubIds(), networkCallback1);
+        mCm.requestNetwork(
+                getRestrictedRequestForWifiWithSubIds(), pendingIntent);
+        mCm.registerNetworkCallback(
+                getRestrictedRequestForWifiWithSubIds(), networkCallback2);
 
         mCm.unregisterNetworkCallback(networkCallback1);
         mCm.releaseNetworkRequest(pendingIntent);
         mCm.unregisterNetworkCallback(networkCallback2);
     }
 
-    @Test
-    @IgnoreUpTo(Build.VERSION_CODES.TIRAMISU)
-    public void testRestrictedRequestRemovedDueToCarrierPrivilegesLost() throws Exception {
-        mServiceContext.setPermission(CONNECTIVITY_USE_RESTRICTED_NETWORKS, PERMISSION_DENIED);
-        NetworkCapabilities filter = getRestrictedRequestForWifiWithSubIds().networkCapabilities;
+    private void doTestNetworkRequestWithCarrierPrivilegesLost(
+            boolean shouldGrantRestrictedNetworkPermission,
+            int lostPrivilegeUid,
+            int lostPrivilegeSubId,
+            boolean expectUnavailable,
+            boolean expectCapChanged) throws Exception {
+        if (shouldGrantRestrictedNetworkPermission) {
+            mServiceContext.setPermission(CONNECTIVITY_USE_RESTRICTED_NETWORKS, PERMISSION_GRANTED);
+        } else {
+            mServiceContext.setPermission(CONNECTIVITY_USE_RESTRICTED_NETWORKS, PERMISSION_DENIED);
+        }
+
+        NetworkCapabilities filter =
+                getRestrictedRequestForWifiWithSubIds().networkCapabilities;
         final HandlerThread handlerThread = new HandlerThread("testRestrictedFactoryRequests");
         handlerThread.start();
+
         final MockNetworkFactory testFactory = new MockNetworkFactory(handlerThread.getLooper(),
                 mServiceContext, "testFactory", filter, mCsHandlerThread);
         testFactory.register();
-
         testFactory.assertRequestCountEquals(0);
+
         doReturn(true).when(mCarrierPrivilegeAuthenticator)
                 .isCarrierServiceUidForNetworkCapabilities(eq(Process.myUid()), any());
-        final TestNetworkCallback networkCallback1 = new TestNetworkCallback();
-        final NetworkRequest networkrequest1 = getRestrictedRequestForWifiWithSubIds();
-        mCm.requestNetwork(networkrequest1, networkCallback1);
+        final TestNetworkCallback networkCallback = new TestNetworkCallback();
+        final NetworkRequest networkrequest =
+                getRestrictedRequestForWifiWithSubIds();
+        mCm.requestNetwork(networkrequest, networkCallback);
         testFactory.expectRequestAdd();
         testFactory.assertRequestCountEquals(1);
 
@@ -17455,24 +17470,36 @@
                 .setAllowedUids(Set.of(Process.myUid()))
                 .build();
         mWiFiAgent = new TestNetworkAgentWrapper(TRANSPORT_WIFI, new LinkProperties(), nc);
-        mWiFiAgent.connect(true);
-        networkCallback1.expectAvailableThenValidatedCallbacks(mWiFiAgent);
-
+        mWiFiAgent.connect(false);
+        networkCallback.expectAvailableCallbacksUnvalidated(mWiFiAgent);
         final NetworkAgentInfo nai = mService.getNetworkAgentInfoForNetwork(
                 mWiFiAgent.getNetwork());
 
         doReturn(false).when(mCarrierPrivilegeAuthenticator)
                 .isCarrierServiceUidForNetworkCapabilities(eq(Process.myUid()), any());
-        final CarrierPrivilegesLostListener carrierPrivilegesLostListener =
-                mService.getCarrierPrivilegesLostListener();
-        carrierPrivilegesLostListener.onCarrierPrivilegesLost(Process.myUid());
+        doReturn(TEST_SUBSCRIPTION_ID).when(mCarrierPrivilegeAuthenticator)
+                .getSubIdFromNetworkCapabilities(any());
+        mService.onCarrierPrivilegesLost(lostPrivilegeUid, lostPrivilegeSubId);
         waitForIdle();
 
-        testFactory.expectRequestRemove();
-        testFactory.assertRequestCountEquals(0);
-        assertTrue(nai.networkCapabilities.getAllowedUidsNoCopy().isEmpty());
-        networkCallback1.expect(NETWORK_CAPS_UPDATED);
-        networkCallback1.expect(UNAVAILABLE);
+        if (expectCapChanged) {
+            networkCallback.expect(NETWORK_CAPS_UPDATED);
+        }
+        if (expectUnavailable) {
+            networkCallback.expect(UNAVAILABLE);
+        }
+        if (!expectCapChanged && !expectUnavailable) {
+            networkCallback.assertNoCallback();
+        }
+
+        mWiFiAgent.disconnect();
+        waitForIdle();
+
+        if (expectUnavailable) {
+            testFactory.assertRequestCountEquals(0);
+        } else {
+            testFactory.assertRequestCountEquals(1);
+        }
 
         handlerThread.quitSafely();
         handlerThread.join();
@@ -17480,64 +17507,45 @@
 
     @Test
     @IgnoreUpTo(Build.VERSION_CODES.TIRAMISU)
+    public void testRestrictedRequestRemovedDueToCarrierPrivilegesLost() throws Exception {
+        doTestNetworkRequestWithCarrierPrivilegesLost(
+                false /* shouldGrantRestrictedNetworkPermission */,
+                Process.myUid(),
+                TEST_SUBSCRIPTION_ID,
+                true /* expectUnavailable */,
+                true /* expectCapChanged */);
+    }
+
+    @Test
+    @IgnoreUpTo(Build.VERSION_CODES.TIRAMISU)
+    public void testRequestNotRemoved_MismatchSubId() throws Exception {
+        doTestNetworkRequestWithCarrierPrivilegesLost(
+                false /* shouldGrantRestrictedNetworkPermission */,
+                Process.myUid(),
+                TEST_SUBSCRIPTION_ID + 1,
+                false /* expectUnavailable */,
+                false /* expectCapChanged */);
+    }
+    @Test
+    @IgnoreUpTo(Build.VERSION_CODES.TIRAMISU)
     public void testRequestNotRemoved_MismatchUid() throws Exception {
-        mServiceContext.setPermission(CONNECTIVITY_USE_RESTRICTED_NETWORKS, PERMISSION_DENIED);
-        NetworkCapabilities filter = getRestrictedRequestForWifiWithSubIds().networkCapabilities;
-        final HandlerThread handlerThread = new HandlerThread("testRestrictedFactoryRequests");
-        handlerThread.start();
-
-        final MockNetworkFactory testFactory = new MockNetworkFactory(handlerThread.getLooper(),
-                mServiceContext, "testFactory", filter, mCsHandlerThread);
-        testFactory.register();
-
-        doReturn(true).when(mCarrierPrivilegeAuthenticator)
-                .isCarrierServiceUidForNetworkCapabilities(anyInt(), any());
-        final TestNetworkCallback networkCallback1 = new TestNetworkCallback();
-        final NetworkRequest networkrequest1 = getRestrictedRequestForWifiWithSubIds();
-        mCm.requestNetwork(networkrequest1, networkCallback1);
-        testFactory.expectRequestAdd();
-        testFactory.assertRequestCountEquals(1);
-
-        doReturn(false).when(mCarrierPrivilegeAuthenticator)
-                .isCarrierServiceUidForNetworkCapabilities(eq(Process.myUid()), any());
-        final CarrierPrivilegesLostListener carrierPrivilegesLostListener =
-                mService.getCarrierPrivilegesLostListener();
-        carrierPrivilegesLostListener.onCarrierPrivilegesLost(Process.myUid() + 1);
-        expectNoRequestChanged(testFactory);
-
-        handlerThread.quitSafely();
-        handlerThread.join();
+        doTestNetworkRequestWithCarrierPrivilegesLost(
+                false /* shouldGrantRestrictedNetworkPermission */,
+                Process.myUid() + 1,
+                TEST_SUBSCRIPTION_ID,
+                false /* expectUnavailable */,
+                false /* expectCapChanged */);
     }
 
     @Test
     @IgnoreUpTo(Build.VERSION_CODES.TIRAMISU)
     public void testRequestNotRemoved_HasRestrictedNetworkPermission() throws Exception {
-        mServiceContext.setPermission(CONNECTIVITY_USE_RESTRICTED_NETWORKS, PERMISSION_GRANTED);
-        NetworkCapabilities filter = getRestrictedRequestForWifiWithSubIds().networkCapabilities;
-        final HandlerThread handlerThread = new HandlerThread("testRestrictedFactoryRequests");
-        handlerThread.start();
-
-        final MockNetworkFactory testFactory = new MockNetworkFactory(handlerThread.getLooper(),
-                mServiceContext, "testFactory", filter, mCsHandlerThread);
-        testFactory.register();
-
-        doReturn(true).when(mCarrierPrivilegeAuthenticator)
-            .isCarrierServiceUidForNetworkCapabilities(anyInt(), any());
-        final TestNetworkCallback networkCallback1 = new TestNetworkCallback();
-        final NetworkRequest networkrequest1 = getRestrictedRequestForWifiWithSubIds();
-        mCm.requestNetwork(networkrequest1, networkCallback1);
-        testFactory.expectRequestAdd();
-        testFactory.assertRequestCountEquals(1);
-
-        doReturn(false).when(mCarrierPrivilegeAuthenticator)
-                .isCarrierServiceUidForNetworkCapabilities(eq(Process.myUid()), any());
-        final CarrierPrivilegesLostListener carrierPrivilegesLostListener =
-                mService.getCarrierPrivilegesLostListener();
-        carrierPrivilegesLostListener.onCarrierPrivilegesLost(Process.myUid());
-        expectNoRequestChanged(testFactory);
-
-        handlerThread.quitSafely();
-        handlerThread.join();
+        doTestNetworkRequestWithCarrierPrivilegesLost(
+                true /* shouldGrantRestrictedNetworkPermission */,
+                Process.myUid(),
+                TEST_SUBSCRIPTION_ID,
+                false /* expectUnavailable */,
+                true /* expectCapChanged */);
     }
     @Test
     public void testAllowedUids() throws Exception {
diff --git a/tests/unit/java/com/android/server/connectivity/CarrierPrivilegeAuthenticatorTest.java b/tests/unit/java/com/android/server/connectivity/CarrierPrivilegeAuthenticatorTest.java
index 9f0ec30..7bd2b56 100644
--- a/tests/unit/java/com/android/server/connectivity/CarrierPrivilegeAuthenticatorTest.java
+++ b/tests/unit/java/com/android/server/connectivity/CarrierPrivilegeAuthenticatorTest.java
@@ -20,7 +20,6 @@
 import static android.net.NetworkCapabilities.TRANSPORT_WIFI;
 import static android.telephony.TelephonyManager.ACTION_MULTI_SIM_CONFIG_CHANGED;
 
-import static com.android.server.connectivity.CarrierPrivilegeAuthenticator.CarrierPrivilegesLostListener;
 import static com.android.server.connectivity.ConnectivityFlags.CARRIER_SERVICE_CHANGED_USE_CALLBACK;
 
 import static org.junit.Assert.assertEquals;
@@ -47,7 +46,6 @@
 import android.net.TelephonyNetworkSpecifier;
 import android.os.Build;
 import android.os.HandlerThread;
-import android.telephony.SubscriptionManager;
 import android.telephony.TelephonyManager;
 
 import com.android.net.module.util.CollectionUtils;
@@ -71,6 +69,7 @@
 import java.util.Collections;
 import java.util.Map;
 import java.util.Set;
+import java.util.function.BiConsumer;
 
 /**
  * Tests for CarrierPrivilegeAuthenticatorTest.
@@ -92,7 +91,7 @@
     @NonNull private final TelephonyManagerShimImpl mTelephonyManagerShim;
     @NonNull private final PackageManager mPackageManager;
     @NonNull private TestCarrierPrivilegeAuthenticator mCarrierPrivilegeAuthenticator;
-    @NonNull private final CarrierPrivilegesLostListener mListener;
+    @NonNull private final BiConsumer<Integer, Integer> mListener;
     private final int mCarrierConfigPkgUid = 12345;
     private final boolean mUseCallbacks;
     private final String mTestPkg = "com.android.server.connectivity.test";
@@ -107,9 +106,8 @@
                     mListener);
         }
         @Override
-        protected int getSlotIndex(int subId) {
-            if (SubscriptionManager.DEFAULT_SUBSCRIPTION_ID == subId) return TEST_SUBSCRIPTION_ID;
-            return subId;
+        protected int getSubId(int slotIndex) {
+            return TEST_SUBSCRIPTION_ID;
         }
     }
 
@@ -129,7 +127,7 @@
         mTelephonyManager = mock(TelephonyManager.class);
         mTelephonyManagerShim = mock(TelephonyManagerShimImpl.class);
         mPackageManager = mock(PackageManager.class);
-        mListener = mock(CarrierPrivilegesLostListener.class);
+        mListener = mock(BiConsumer.class);
         mHandlerThread = new HandlerThread(CarrierPrivilegeAuthenticatorTest.class.getSimpleName());
         mUseCallbacks = useCallbacks;
         final Dependencies deps = mock(Dependencies.class);
@@ -184,7 +182,7 @@
 
         final NetworkCapabilities.Builder ncBuilder = new NetworkCapabilities.Builder()
                 .addTransportType(TRANSPORT_CELLULAR)
-                .setNetworkSpecifier(new TelephonyNetworkSpecifier(0));
+                .setNetworkSpecifier(new TelephonyNetworkSpecifier(TEST_SUBSCRIPTION_ID));
 
         assertTrue(mCarrierPrivilegeAuthenticator.isCarrierServiceUidForNetworkCapabilities(
                 mCarrierConfigPkgUid, ncBuilder.build()));
@@ -220,7 +218,8 @@
 
         newListeners.get(0).onCarrierServiceChanged(null, mCarrierConfigPkgUid);
 
-        final TelephonyNetworkSpecifier specifier = new TelephonyNetworkSpecifier(0);
+        final TelephonyNetworkSpecifier specifier =
+                new TelephonyNetworkSpecifier(TEST_SUBSCRIPTION_ID);
         final NetworkCapabilities nc = new NetworkCapabilities.Builder()
                 .addTransportType(TRANSPORT_CELLULAR)
                 .setNetworkSpecifier(specifier)
@@ -239,7 +238,11 @@
         l.onCarrierServiceChanged(null, mCarrierConfigPkgUid);
         l.onCarrierServiceChanged(null, mCarrierConfigPkgUid + 1);
         if (mUseCallbacks) {
-            verify(mListener).onCarrierPrivilegesLost(eq(mCarrierConfigPkgUid));
+            verify(mListener).accept(eq(mCarrierConfigPkgUid), eq(TEST_SUBSCRIPTION_ID));
+        }
+        l.onCarrierServiceChanged(null, mCarrierConfigPkgUid + 2);
+        if (mUseCallbacks) {
+            verify(mListener).accept(eq(mCarrierConfigPkgUid + 1), eq(TEST_SUBSCRIPTION_ID));
         }
     }
 
@@ -247,7 +250,8 @@
     public void testOnCarrierPrivilegesChanged() throws Exception {
         final CarrierPrivilegesListenerShim listener = getCarrierPrivilegesListeners().get(0);
 
-        final TelephonyNetworkSpecifier specifier = new TelephonyNetworkSpecifier(0);
+        final TelephonyNetworkSpecifier specifier =
+                new TelephonyNetworkSpecifier(TEST_SUBSCRIPTION_ID);
         final NetworkCapabilities nc = new NetworkCapabilities.Builder()
                 .addTransportType(TRANSPORT_CELLULAR)
                 .setNetworkSpecifier(specifier)
@@ -275,7 +279,7 @@
         assertFalse(mCarrierPrivilegeAuthenticator.isCarrierServiceUidForNetworkCapabilities(
                 mCarrierConfigPkgUid, ncBuilder.build()));
 
-        ncBuilder.setNetworkSpecifier(new TelephonyNetworkSpecifier(0));
+        ncBuilder.setNetworkSpecifier(new TelephonyNetworkSpecifier(TEST_SUBSCRIPTION_ID));
         assertTrue(mCarrierPrivilegeAuthenticator.isCarrierServiceUidForNetworkCapabilities(
                 mCarrierConfigPkgUid, ncBuilder.build()));
 
@@ -284,7 +288,7 @@
         ncBuilder.setNetworkSpecifier(null);
         ncBuilder.removeTransportType(TRANSPORT_CELLULAR);
         ncBuilder.addTransportType(TRANSPORT_WIFI);
-        ncBuilder.setNetworkSpecifier(new TelephonyNetworkSpecifier(0));
+        ncBuilder.setNetworkSpecifier(new TelephonyNetworkSpecifier(TEST_SUBSCRIPTION_ID));
         assertFalse(mCarrierPrivilegeAuthenticator.isCarrierServiceUidForNetworkCapabilities(
                 mCarrierConfigPkgUid, ncBuilder.build()));
     }
@@ -298,7 +302,7 @@
         final NetworkCapabilities.Builder ncBuilder = new NetworkCapabilities.Builder();
         ncBuilder.addTransportType(TRANSPORT_WIFI);
         ncBuilder.removeCapability(NetworkCapabilities.NET_CAPABILITY_NOT_RESTRICTED);
-        ncBuilder.setSubscriptionIds(Set.of(0));
+        ncBuilder.setSubscriptionIds(Set.of(TEST_SUBSCRIPTION_ID));
         assertTrue(mCarrierPrivilegeAuthenticator.isCarrierServiceUidForNetworkCapabilities(
                 mCarrierConfigPkgUid, ncBuilder.build()));
     }
diff --git a/tests/unit/java/com/android/server/connectivityservice/base/CSTest.kt b/tests/unit/java/com/android/server/connectivityservice/base/CSTest.kt
index b15c684..e401434 100644
--- a/tests/unit/java/com/android/server/connectivityservice/base/CSTest.kt
+++ b/tests/unit/java/com/android/server/connectivityservice/base/CSTest.kt
@@ -60,7 +60,6 @@
 import com.android.networkstack.apishim.common.UnsupportedApiLevelException
 import com.android.server.connectivity.AutomaticOnOffKeepaliveTracker
 import com.android.server.connectivity.CarrierPrivilegeAuthenticator
-import com.android.server.connectivity.CarrierPrivilegeAuthenticator.CarrierPrivilegesLostListener
 import com.android.server.connectivity.ClatCoordinator
 import com.android.server.connectivity.ConnectivityFlags
 import com.android.server.connectivity.MulticastRoutingCoordinatorService
@@ -73,6 +72,7 @@
 import com.android.testutils.waitForIdle
 import java.util.concurrent.Executors
 import java.util.function.Consumer
+import java.util.function.BiConsumer
 import kotlin.test.assertNull
 import kotlin.test.fail
 import org.junit.After
@@ -222,7 +222,7 @@
                 context: Context,
                 tm: TelephonyManager,
                 requestRestrictedWifiEnabled: Boolean,
-                listener: CarrierPrivilegesLostListener
+                listener: BiConsumer<Int, Int>
         ) = if (SdkLevel.isAtLeastT()) mock<CarrierPrivilegeAuthenticator>() else null
 
         var satelliteNetworkFallbackUidUpdate: Consumer<Set<Int>>? = null