Merge "Store the SubscriptionId together with the Carrier UID" into main
diff --git a/service/src/com/android/server/ConnectivityService.java b/service/src/com/android/server/ConnectivityService.java
index e6287bc..3d646fd 100755
--- a/service/src/com/android/server/ConnectivityService.java
+++ b/service/src/com/android/server/ConnectivityService.java
@@ -114,7 +114,6 @@
 import static com.android.net.module.util.PermissionUtils.enforceNetworkStackPermissionOr;
 import static com.android.net.module.util.PermissionUtils.hasAnyPermissionOf;
 import static com.android.server.ConnectivityStatsLog.CONNECTIVITY_STATE_SAMPLE;
-import static com.android.server.connectivity.CarrierPrivilegeAuthenticator.CarrierPrivilegesLostListener;
 import static com.android.server.connectivity.ConnectivityFlags.REQUEST_RESTRICTED_WIFI;
 
 import android.Manifest;
@@ -257,6 +256,7 @@
 import android.stats.connectivity.ValidatedState;
 import android.sysprop.NetworkProperties;
 import android.system.ErrnoException;
+import android.telephony.SubscriptionManager;
 import android.telephony.TelephonyManager;
 import android.text.TextUtils;
 import android.util.ArrayMap;
@@ -377,6 +377,7 @@
 import java.util.TreeSet;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicInteger;
+import java.util.function.BiConsumer;
 import java.util.function.Consumer;
 
 /**
@@ -1287,18 +1288,14 @@
     }
     private final LegacyTypeTracker mLegacyTypeTracker = new LegacyTypeTracker(this);
 
-    private final CarrierPrivilegesLostListenerImpl mCarrierPrivilegesLostListenerImpl =
-            new CarrierPrivilegesLostListenerImpl();
-
-    private class CarrierPrivilegesLostListenerImpl implements CarrierPrivilegesLostListener {
-        @Override
-        public void onCarrierPrivilegesLost(int uid) {
-            if (mRequestRestrictedWifiEnabled) {
-                mHandler.sendMessage(mHandler.obtainMessage(
-                        EVENT_UID_CARRIER_PRIVILEGES_LOST, uid, 0 /* arg2 */));
-            }
+    @VisibleForTesting
+    void onCarrierPrivilegesLost(Integer uid, Integer subId) {
+        if (mRequestRestrictedWifiEnabled) {
+            mHandler.sendMessage(mHandler.obtainMessage(
+                    EVENT_UID_CARRIER_PRIVILEGES_LOST, uid, subId));
         }
     }
+
     final LocalPriorityDump mPriorityDumper = new LocalPriorityDump();
     /**
      * Helper class which parses out priority arguments and dumps sections according to their
@@ -1357,11 +1354,6 @@
         }
     }
 
-    @VisibleForTesting
-    CarrierPrivilegesLostListener getCarrierPrivilegesLostListener() {
-        return mCarrierPrivilegesLostListenerImpl;
-    }
-
     /**
      * Dependencies of ConnectivityService, for injection in tests.
      */
@@ -1525,7 +1517,7 @@
                 @NonNull final Context context,
                 @NonNull final TelephonyManager tm,
                 boolean requestRestrictedWifiEnabled,
-                @NonNull CarrierPrivilegesLostListener listener) {
+                @NonNull BiConsumer<Integer, Integer> listener) {
             if (isAtLeastT()) {
                 return new CarrierPrivilegeAuthenticator(
                         context, tm, requestRestrictedWifiEnabled, listener);
@@ -1813,7 +1805,7 @@
                 && mDeps.isFeatureEnabled(context, REQUEST_RESTRICTED_WIFI);
         mCarrierPrivilegeAuthenticator = mDeps.makeCarrierPrivilegeAuthenticator(
                 mContext, mTelephonyManager, mRequestRestrictedWifiEnabled,
-                mCarrierPrivilegesLostListenerImpl);
+                this::onCarrierPrivilegesLost);
 
         if (mDeps.isAtLeastU()
                 && mDeps
@@ -5401,6 +5393,13 @@
         return false;
     }
 
+    private int getSubscriptionIdFromNetworkCaps(@NonNull final NetworkCapabilities caps) {
+        if (mCarrierPrivilegeAuthenticator != null) {
+            return mCarrierPrivilegeAuthenticator.getSubIdFromNetworkCapabilities(caps);
+        }
+        return SubscriptionManager.INVALID_SUBSCRIPTION_ID;
+    }
+
     private void handleRegisterNetworkRequestWithIntent(@NonNull final Message msg) {
         final NetworkRequestInfo nri = (NetworkRequestInfo) (msg.obj);
         // handleRegisterNetworkRequestWithIntent() doesn't apply to multilayer requests.
@@ -6492,7 +6491,7 @@
                     handleFrozenUids(args.mUids, args.mFrozenStates);
                     break;
                 case EVENT_UID_CARRIER_PRIVILEGES_LOST:
-                    handleUidCarrierPrivilegesLost(msg.arg1);
+                    handleUidCarrierPrivilegesLost(msg.arg1, msg.arg2);
                     break;
             }
         }
@@ -9155,7 +9154,7 @@
         }
     }
 
-    private void handleUidCarrierPrivilegesLost(int uid) {
+    private void handleUidCarrierPrivilegesLost(int uid, int subId) {
         ensureRunningOnConnectivityServiceThread();
         // A NetworkRequest needs to be revoked when all the conditions are met
         //   1. It requests restricted network
@@ -9166,6 +9165,7 @@
             if ((nr.isRequest() || nr.isListen())
                     && !nr.hasCapability(NET_CAPABILITY_NOT_RESTRICTED)
                     && nr.getRequestorUid() == uid
+                    && getSubscriptionIdFromNetworkCaps(nr.networkCapabilities) == subId
                     && !hasConnectivityRestrictedNetworksPermission(uid, true)) {
                 declareNetworkRequestUnfulfillable(nr);
             }
@@ -9174,7 +9174,8 @@
         // A NetworkAgent's allowedUids may need to be updated if the app has lost
         // carrier config
         for (final NetworkAgentInfo nai : mNetworkAgentInfos) {
-            if (nai.networkCapabilities.getAllowedUidsNoCopy().contains(uid)) {
+            if (nai.networkCapabilities.getAllowedUidsNoCopy().contains(uid)
+                    && getSubscriptionIdFromNetworkCaps(nai.networkCapabilities) == subId) {
                 final NetworkCapabilities nc = new NetworkCapabilities(nai.networkCapabilities);
                 NetworkAgentInfo.restrictCapabilitiesFromNetworkAgent(
                         nc,
diff --git a/service/src/com/android/server/connectivity/CarrierPrivilegeAuthenticator.java b/service/src/com/android/server/connectivity/CarrierPrivilegeAuthenticator.java
index 533278e..04d0fc1 100644
--- a/service/src/com/android/server/connectivity/CarrierPrivilegeAuthenticator.java
+++ b/service/src/com/android/server/connectivity/CarrierPrivilegeAuthenticator.java
@@ -40,12 +40,13 @@
 import android.telephony.SubscriptionManager;
 import android.telephony.TelephonyManager;
 import android.util.Log;
-import android.util.SparseIntArray;
+import android.util.SparseArray;
 
 import com.android.internal.annotations.GuardedBy;
 import com.android.internal.annotations.VisibleForTesting;
 import com.android.internal.util.IndentingPrintWriter;
 import com.android.modules.utils.HandlerExecutor;
+import com.android.modules.utils.build.SdkLevel;
 import com.android.net.module.util.DeviceConfigUtils;
 import com.android.networkstack.apishim.TelephonyManagerShimImpl;
 import com.android.networkstack.apishim.common.TelephonyManagerShim;
@@ -55,6 +56,7 @@
 import java.util.ArrayList;
 import java.util.List;
 import java.util.concurrent.Executor;
+import java.util.function.BiConsumer;
 
 /**
  * Tracks the uid of the carrier privileged app that provides the carrier config.
@@ -71,7 +73,8 @@
     private final TelephonyManagerShim mTelephonyManagerShim;
     private final TelephonyManager mTelephonyManager;
     @GuardedBy("mLock")
-    private final SparseIntArray mCarrierServiceUid = new SparseIntArray(2 /* initialCapacity */);
+    private final SparseArray<CarrierServiceUidWithSubId> mCarrierServiceUidWithSubId =
+            new SparseArray<>(2 /* initialCapacity */);
     @GuardedBy("mLock")
     private int mModemCount = 0;
     private final Object mLock = new Object();
@@ -81,14 +84,14 @@
     private final boolean mUseCallbacksForServiceChanged;
     private final boolean mRequestRestrictedWifiEnabled;
     @NonNull
-    private final CarrierPrivilegesLostListener mListener;
+    private final BiConsumer<Integer, Integer> mListener;
 
     public CarrierPrivilegeAuthenticator(@NonNull final Context c,
             @NonNull final Dependencies deps,
             @NonNull final TelephonyManager t,
             @NonNull final TelephonyManagerShim telephonyManagerShim,
             final boolean requestRestrictedWifiEnabled,
-            @NonNull CarrierPrivilegesLostListener listener) {
+            @NonNull BiConsumer<Integer, Integer> listener) {
         mContext = c;
         mTelephonyManager = t;
         mTelephonyManagerShim = telephonyManagerShim;
@@ -121,7 +124,7 @@
 
     public CarrierPrivilegeAuthenticator(@NonNull final Context c,
             @NonNull final TelephonyManager t, final boolean requestRestrictedWifiEnabled,
-            @NonNull CarrierPrivilegesLostListener listener) {
+            @NonNull BiConsumer<Integer, Integer> listener) {
         this(c, new Dependencies(), t, TelephonyManagerShimImpl.newInstance(t),
                 requestRestrictedWifiEnabled, listener);
     }
@@ -142,18 +145,6 @@
         }
     }
 
-    /**
-     * Listener interface to get a notification when the carrier App lost its privileges.
-     */
-    public interface CarrierPrivilegesLostListener {
-        /**
-         * Called when the carrier App lost its privileges.
-         *
-         * @param uid  The uid of the carrier app which has lost its privileges.
-         */
-        void onCarrierPrivilegesLost(int uid);
-    }
-
     private void simConfigChanged() {
         synchronized (mLock) {
             unregisterCarrierPrivilegesListeners();
@@ -163,6 +154,29 @@
         }
     }
 
+    private static class CarrierServiceUidWithSubId {
+        final int mUid;
+        final int mSubId;
+
+        CarrierServiceUidWithSubId(int uid, int subId) {
+            mUid = uid;
+            mSubId = subId;
+        }
+
+        @Override
+        public boolean equals(Object obj) {
+            if (!(obj instanceof CarrierServiceUidWithSubId)) {
+                return false;
+            }
+            CarrierServiceUidWithSubId compare = (CarrierServiceUidWithSubId) obj;
+            return (mUid == compare.mUid && mSubId == compare.mSubId);
+        }
+
+        @Override
+        public int hashCode() {
+            return mUid * 31 + mSubId;
+        }
+    }
     private class PrivilegeListener implements CarrierPrivilegesListenerShim {
         public final int mLogicalSlot;
 
@@ -192,10 +206,17 @@
                 return;
             }
             synchronized (mLock) {
-                int oldUid = mCarrierServiceUid.get(mLogicalSlot);
-                mCarrierServiceUid.put(mLogicalSlot, carrierServiceUid);
-                if (oldUid != 0 && oldUid != carrierServiceUid) {
-                    mListener.onCarrierPrivilegesLost(oldUid);
+                CarrierServiceUidWithSubId oldPair =
+                        mCarrierServiceUidWithSubId.get(mLogicalSlot);
+                int subId = getSubId(mLogicalSlot);
+                mCarrierServiceUidWithSubId.put(
+                        mLogicalSlot,
+                        new CarrierServiceUidWithSubId(carrierServiceUid, subId));
+                if (oldPair != null
+                        && oldPair.mUid != Process.INVALID_UID
+                        && oldPair.mSubId != SubscriptionManager.INVALID_SUBSCRIPTION_ID
+                        && !oldPair.equals(mCarrierServiceUidWithSubId.get(mLogicalSlot))) {
+                    mListener.accept(oldPair.mUid, oldPair.mSubId);
                 }
             }
         }
@@ -218,10 +239,13 @@
     private void unregisterCarrierPrivilegesListeners() {
         for (PrivilegeListener carrierPrivilegesListener : mCarrierPrivilegesChangedListeners) {
             removeCarrierPrivilegesListener(carrierPrivilegesListener);
-            int oldUid = mCarrierServiceUid.get(carrierPrivilegesListener.mLogicalSlot);
-            mCarrierServiceUid.delete(carrierPrivilegesListener.mLogicalSlot);
-            if (oldUid != 0) {
-                mListener.onCarrierPrivilegesLost(oldUid);
+            CarrierServiceUidWithSubId oldPair =
+                    mCarrierServiceUidWithSubId.get(carrierPrivilegesListener.mLogicalSlot);
+            mCarrierServiceUidWithSubId.remove(carrierPrivilegesListener.mLogicalSlot);
+            if (oldPair != null
+                    && oldPair.mUid != Process.INVALID_UID
+                    && oldPair.mSubId != SubscriptionManager.INVALID_SUBSCRIPTION_ID) {
+                mListener.accept(oldPair.mUid, oldPair.mSubId);
             }
         }
         mCarrierPrivilegesChangedListeners.clear();
@@ -259,7 +283,23 @@
      */
     public boolean isCarrierServiceUidForNetworkCapabilities(int callingUid,
             @NonNull NetworkCapabilities networkCapabilities) {
-        if (callingUid == Process.INVALID_UID) return false;
+        if (callingUid == Process.INVALID_UID) {
+            return false;
+        }
+        int subId = getSubIdFromNetworkCapabilities(networkCapabilities);
+        if (SubscriptionManager.INVALID_SUBSCRIPTION_ID == subId) {
+            return false;
+        }
+        return callingUid == getCarrierServiceUidForSubId(subId);
+    }
+
+    /**
+     * Extract the SubscriptionId from the NetworkCapabilities.
+     *
+     * @param networkCapabilities the network capabilities which may contain the SubscriptionId.
+     * @return the SubscriptionId.
+     */
+    public int getSubIdFromNetworkCapabilities(@NonNull NetworkCapabilities networkCapabilities) {
         int subId;
         if (networkCapabilities.hasSingleTransportBesidesTest(TRANSPORT_CELLULAR)) {
             subId = getSubIdFromTelephonySpecifier(networkCapabilities.getNetworkSpecifier());
@@ -285,21 +325,42 @@
             Log.wtf(TAG, "NetworkCapabilities subIds are inconsistent between "
                     + "specifier/transportInfo and mSubIds : " + networkCapabilities);
         }
-        if (SubscriptionManager.INVALID_SUBSCRIPTION_ID == subId) return false;
-        return callingUid == getCarrierServiceUidForSubId(subId);
+        return subId;
+    }
+
+    @VisibleForTesting
+    protected int getSubId(int slotIndex) {
+        if (SdkLevel.isAtLeastU()) {
+            return SubscriptionManager.getSubscriptionId(slotIndex);
+        } else {
+            SubscriptionManager sm = mContext.getSystemService(SubscriptionManager.class);
+            int[] subIds = sm.getSubscriptionIds(slotIndex);
+            if (subIds != null && subIds.length > 0) {
+                return subIds[0];
+            }
+            return SubscriptionManager.INVALID_SUBSCRIPTION_ID;
+        }
     }
 
     @VisibleForTesting
     void updateCarrierServiceUid() {
         synchronized (mLock) {
-            SparseIntArray oldCarrierServiceUid = mCarrierServiceUid.clone();
-            mCarrierServiceUid.clear();
+            SparseArray<CarrierServiceUidWithSubId> copy = mCarrierServiceUidWithSubId.clone();
+            mCarrierServiceUidWithSubId.clear();
             for (int i = 0; i < mModemCount; i++) {
-                mCarrierServiceUid.put(i, getCarrierServicePackageUidForSlot(i));
+                int subId = getSubId(i);
+                mCarrierServiceUidWithSubId.put(
+                        i,
+                        new CarrierServiceUidWithSubId(
+                                getCarrierServicePackageUidForSlot(i), subId));
             }
-            for (int i = 0; i < oldCarrierServiceUid.size(); i++) {
-                if (mCarrierServiceUid.indexOfValue(oldCarrierServiceUid.valueAt(i)) < 0) {
-                    mListener.onCarrierPrivilegesLost(oldCarrierServiceUid.valueAt(i));
+            for (int i = 0; i < copy.size(); ++i) {
+                CarrierServiceUidWithSubId oldPair = copy.valueAt(i);
+                CarrierServiceUidWithSubId newPair = mCarrierServiceUidWithSubId.get(copy.keyAt(i));
+                if (oldPair.mUid != Process.INVALID_UID
+                        && oldPair.mSubId != SubscriptionManager.INVALID_SUBSCRIPTION_ID
+                        && !oldPair.equals(newPair)) {
+                    mListener.accept(oldPair.mUid, oldPair.mSubId);
                 }
             }
         }
@@ -307,18 +368,17 @@
 
     @VisibleForTesting
     int getCarrierServiceUidForSubId(int subId) {
-        final int slotId = getSlotIndex(subId);
         synchronized (mLock) {
-            return mCarrierServiceUid.get(slotId, Process.INVALID_UID);
+            for (int i = 0; i < mCarrierServiceUidWithSubId.size(); ++i) {
+                if (mCarrierServiceUidWithSubId.valueAt(i).mSubId == subId) {
+                    return mCarrierServiceUidWithSubId.valueAt(i).mUid;
+                }
+            }
+            return Process.INVALID_UID;
         }
     }
 
     @VisibleForTesting
-    protected int getSlotIndex(int subId) {
-        return SubscriptionManager.getSlotIndex(subId);
-    }
-
-    @VisibleForTesting
     int getUidForPackage(String pkgName) {
         if (pkgName == null) {
             return Process.INVALID_UID;
@@ -383,11 +443,12 @@
         pw.println("CarrierPrivilegeAuthenticator:");
         pw.println("mRequestRestrictedWifiEnabled = " + mRequestRestrictedWifiEnabled);
         synchronized (mLock) {
-            final int size = mCarrierServiceUid.size();
-            for (int i = 0; i < size; ++i) {
-                final int logicalSlot = mCarrierServiceUid.keyAt(i);
-                final int serviceUid = mCarrierServiceUid.valueAt(i);
-                pw.println("Logical slot = " + logicalSlot + " : uid = " + serviceUid);
+            for (int i = 0; i < mCarrierServiceUidWithSubId.size(); ++i) {
+                final int logicalSlot = mCarrierServiceUidWithSubId.keyAt(i);
+                final int serviceUid = mCarrierServiceUidWithSubId.valueAt(i).mUid;
+                final int subId = mCarrierServiceUidWithSubId.valueAt(i).mSubId;
+                pw.println("Logical slot = " + logicalSlot + " : uid = " + serviceUid
+                        + " : subId = " + subId);
             }
         }
     }
diff --git a/tests/integration/src/com/android/server/net/integrationtests/ConnectivityServiceIntegrationTest.kt b/tests/integration/src/com/android/server/net/integrationtests/ConnectivityServiceIntegrationTest.kt
index 9148770..361d68c 100644
--- a/tests/integration/src/com/android/server/net/integrationtests/ConnectivityServiceIntegrationTest.kt
+++ b/tests/integration/src/com/android/server/net/integrationtests/ConnectivityServiceIntegrationTest.kt
@@ -56,7 +56,6 @@
 import com.android.server.NetworkAgentWrapper
 import com.android.server.TestNetIdManager
 import com.android.server.connectivity.CarrierPrivilegeAuthenticator
-import com.android.server.connectivity.CarrierPrivilegeAuthenticator.CarrierPrivilegesLostListener
 import com.android.server.connectivity.ConnectivityResources
 import com.android.server.connectivity.MockableSystemProperties
 import com.android.server.connectivity.MultinetworkPolicyTracker
@@ -89,6 +88,7 @@
 import org.mockito.MockitoAnnotations
 import org.mockito.Spy
 import java.util.function.Consumer
+import java.util.function.BiConsumer
 
 const val SERVICE_BIND_TIMEOUT_MS = 5_000L
 const val TEST_TIMEOUT_MS = 10_000L
@@ -245,7 +245,7 @@
             context: Context,
             tm: TelephonyManager,
             requestRestrictedWifiEnabled: Boolean,
-            listener: CarrierPrivilegesLostListener
+            listener: BiConsumer<Int, Int>
         ): CarrierPrivilegeAuthenticator {
             return CarrierPrivilegeAuthenticator(context,
                 object : CarrierPrivilegeAuthenticator.Dependencies() {
diff --git a/tests/unit/java/com/android/server/ConnectivityServiceTest.java b/tests/unit/java/com/android/server/ConnectivityServiceTest.java
index 6623bbd..c534025 100755
--- a/tests/unit/java/com/android/server/ConnectivityServiceTest.java
+++ b/tests/unit/java/com/android/server/ConnectivityServiceTest.java
@@ -173,7 +173,6 @@
 import static com.android.server.ConnectivityServiceTestUtils.transportToLegacyType;
 import static com.android.server.NetworkAgentWrapper.CallbackType.OnQosCallbackRegister;
 import static com.android.server.NetworkAgentWrapper.CallbackType.OnQosCallbackUnregister;
-import static com.android.server.connectivity.CarrierPrivilegeAuthenticator.CarrierPrivilegesLostListener;
 import static com.android.testutils.Cleanup.testAndCleanup;
 import static com.android.testutils.ConcurrentUtils.await;
 import static com.android.testutils.ConcurrentUtils.durationOf;
@@ -488,6 +487,7 @@
 import java.util.concurrent.TimeoutException;
 import java.util.concurrent.atomic.AtomicBoolean;
 import java.util.concurrent.atomic.AtomicReference;
+import java.util.function.BiConsumer;
 import java.util.function.Consumer;
 import java.util.function.Predicate;
 import java.util.function.Supplier;
@@ -526,7 +526,7 @@
     // between a LOST callback that arrives immediately and a LOST callback that arrives after
     // the linger/nascent timeout. For this, our assertions should run fast enough to leave
     // less than (mService.mLingerDelayMs - TEST_CALLBACK_TIMEOUT_MS) between the time callbacks are
-    // supposedly fired, and the time we call expectCallback.
+    // supposedly fired, and the time we call expectCallback.
     private static final int TEST_CALLBACK_TIMEOUT_MS = 250;
     // Chosen to be less than TEST_CALLBACK_TIMEOUT_MS. This ensures that requests have time to
     // complete before callbacks are verified.
@@ -565,6 +565,7 @@
     private static final int TEST_PACKAGE_UID2 = 321;
     private static final int TEST_PACKAGE_UID3 = 456;
     private static final int NETWORK_ACTIVITY_NO_UID = -1;
+    private static final int TEST_SUBSCRIPTION_ID = 1;
 
     private static final int PACKET_WAKEUP_MARK_MASK = 0x80000000;
 
@@ -2059,7 +2060,7 @@
                 @NonNull final Context context,
                 @NonNull final TelephonyManager tm,
                 final boolean requestRestrictedWifiEnabled,
-                CarrierPrivilegesLostListener listener) {
+                BiConsumer<Integer, Integer> listener) {
             return mDeps.isAtLeastT() ? mCarrierPrivilegeAuthenticator : null;
         }
 
@@ -11486,7 +11487,7 @@
         doTestInterfaceClassActivityChanged(TRANSPORT_CELLULAR);
     }
 
-    private void doTestOnNetworkActive_NewNetworkConnects(int transportType, boolean expectCallback)
+    private void doTestOnNetworkActive_NewNetworkConnects(int transportType, boolean expectCapChanged)
             throws Exception {
         final ConditionVariable onNetworkActiveCv = new ConditionVariable();
         final ConnectivityManager.OnNetworkActiveListener listener = onNetworkActiveCv::open;
@@ -11498,7 +11499,7 @@
         testAndCleanup(() -> {
             mCm.addDefaultNetworkActiveListener(listener);
             agent.connect(true);
-            if (expectCallback) {
+            if (expectCapChanged) {
                 assertTrue(onNetworkActiveCv.block(TEST_CALLBACK_TIMEOUT_MS));
             } else {
                 assertFalse(onNetworkActiveCv.block(TEST_CALLBACK_TIMEOUT_MS));
@@ -11513,7 +11514,7 @@
 
     @Test
     public void testOnNetworkActive_NewCellConnects_CallbackCalled() throws Exception {
-        doTestOnNetworkActive_NewNetworkConnects(TRANSPORT_CELLULAR, true /* expectCallback */);
+        doTestOnNetworkActive_NewNetworkConnects(TRANSPORT_CELLULAR, true /* expectCapChanged */);
     }
 
     @Test
@@ -11522,8 +11523,8 @@
         // networks that tracker adds the idle timer to. And the tracker does not set the idle timer
         // for the ethernet network.
         // So onNetworkActive is not called when the ethernet becomes the default network
-        final boolean expectCallback = mDeps.isAtLeastV();
-        doTestOnNetworkActive_NewNetworkConnects(TRANSPORT_ETHERNET, expectCallback);
+        final boolean expectCapChanged = mDeps.isAtLeastV();
+        doTestOnNetworkActive_NewNetworkConnects(TRANSPORT_ETHERNET, expectCapChanged);
     }
 
     @Test
@@ -17375,7 +17376,7 @@
         return new NetworkRequest.Builder()
             .addTransportType(NetworkCapabilities.TRANSPORT_WIFI)
             .removeCapability(NetworkCapabilities.NET_CAPABILITY_NOT_RESTRICTED)
-            .setSubscriptionIds(Collections.singleton(Process.myUid()))
+            .setSubscriptionIds(Collections.singleton(TEST_SUBSCRIPTION_ID))
             .build();
     }
 
@@ -17422,32 +17423,46 @@
         final NetworkCallback networkCallback1 = new NetworkCallback();
         final NetworkCallback networkCallback2 = new NetworkCallback();
 
-        mCm.requestNetwork(getRestrictedRequestForWifiWithSubIds(), networkCallback1);
-        mCm.requestNetwork(getRestrictedRequestForWifiWithSubIds(), pendingIntent);
-        mCm.registerNetworkCallback(getRestrictedRequestForWifiWithSubIds(), networkCallback2);
+        mCm.requestNetwork(
+                getRestrictedRequestForWifiWithSubIds(), networkCallback1);
+        mCm.requestNetwork(
+                getRestrictedRequestForWifiWithSubIds(), pendingIntent);
+        mCm.registerNetworkCallback(
+                getRestrictedRequestForWifiWithSubIds(), networkCallback2);
 
         mCm.unregisterNetworkCallback(networkCallback1);
         mCm.releaseNetworkRequest(pendingIntent);
         mCm.unregisterNetworkCallback(networkCallback2);
     }
 
-    @Test
-    @IgnoreUpTo(Build.VERSION_CODES.TIRAMISU)
-    public void testRestrictedRequestRemovedDueToCarrierPrivilegesLost() throws Exception {
-        mServiceContext.setPermission(CONNECTIVITY_USE_RESTRICTED_NETWORKS, PERMISSION_DENIED);
-        NetworkCapabilities filter = getRestrictedRequestForWifiWithSubIds().networkCapabilities;
+    private void doTestNetworkRequestWithCarrierPrivilegesLost(
+            boolean shouldGrantRestrictedNetworkPermission,
+            int lostPrivilegeUid,
+            int lostPrivilegeSubId,
+            boolean expectUnavailable,
+            boolean expectCapChanged) throws Exception {
+        if (shouldGrantRestrictedNetworkPermission) {
+            mServiceContext.setPermission(CONNECTIVITY_USE_RESTRICTED_NETWORKS, PERMISSION_GRANTED);
+        } else {
+            mServiceContext.setPermission(CONNECTIVITY_USE_RESTRICTED_NETWORKS, PERMISSION_DENIED);
+        }
+
+        NetworkCapabilities filter =
+                getRestrictedRequestForWifiWithSubIds().networkCapabilities;
         final HandlerThread handlerThread = new HandlerThread("testRestrictedFactoryRequests");
         handlerThread.start();
+
         final MockNetworkFactory testFactory = new MockNetworkFactory(handlerThread.getLooper(),
                 mServiceContext, "testFactory", filter, mCsHandlerThread);
         testFactory.register();
-
         testFactory.assertRequestCountEquals(0);
+
         doReturn(true).when(mCarrierPrivilegeAuthenticator)
                 .isCarrierServiceUidForNetworkCapabilities(eq(Process.myUid()), any());
-        final TestNetworkCallback networkCallback1 = new TestNetworkCallback();
-        final NetworkRequest networkrequest1 = getRestrictedRequestForWifiWithSubIds();
-        mCm.requestNetwork(networkrequest1, networkCallback1);
+        final TestNetworkCallback networkCallback = new TestNetworkCallback();
+        final NetworkRequest networkrequest =
+                getRestrictedRequestForWifiWithSubIds();
+        mCm.requestNetwork(networkrequest, networkCallback);
         testFactory.expectRequestAdd();
         testFactory.assertRequestCountEquals(1);
 
@@ -17455,24 +17470,36 @@
                 .setAllowedUids(Set.of(Process.myUid()))
                 .build();
         mWiFiAgent = new TestNetworkAgentWrapper(TRANSPORT_WIFI, new LinkProperties(), nc);
-        mWiFiAgent.connect(true);
-        networkCallback1.expectAvailableThenValidatedCallbacks(mWiFiAgent);
-
+        mWiFiAgent.connect(false);
+        networkCallback.expectAvailableCallbacksUnvalidated(mWiFiAgent);
         final NetworkAgentInfo nai = mService.getNetworkAgentInfoForNetwork(
                 mWiFiAgent.getNetwork());
 
         doReturn(false).when(mCarrierPrivilegeAuthenticator)
                 .isCarrierServiceUidForNetworkCapabilities(eq(Process.myUid()), any());
-        final CarrierPrivilegesLostListener carrierPrivilegesLostListener =
-                mService.getCarrierPrivilegesLostListener();
-        carrierPrivilegesLostListener.onCarrierPrivilegesLost(Process.myUid());
+        doReturn(TEST_SUBSCRIPTION_ID).when(mCarrierPrivilegeAuthenticator)
+                .getSubIdFromNetworkCapabilities(any());
+        mService.onCarrierPrivilegesLost(lostPrivilegeUid, lostPrivilegeSubId);
         waitForIdle();
 
-        testFactory.expectRequestRemove();
-        testFactory.assertRequestCountEquals(0);
-        assertTrue(nai.networkCapabilities.getAllowedUidsNoCopy().isEmpty());
-        networkCallback1.expect(NETWORK_CAPS_UPDATED);
-        networkCallback1.expect(UNAVAILABLE);
+        if (expectCapChanged) {
+            networkCallback.expect(NETWORK_CAPS_UPDATED);
+        }
+        if (expectUnavailable) {
+            networkCallback.expect(UNAVAILABLE);
+        }
+        if (!expectCapChanged && !expectUnavailable) {
+            networkCallback.assertNoCallback();
+        }
+
+        mWiFiAgent.disconnect();
+        waitForIdle();
+
+        if (expectUnavailable) {
+            testFactory.assertRequestCountEquals(0);
+        } else {
+            testFactory.assertRequestCountEquals(1);
+        }
 
         handlerThread.quitSafely();
         handlerThread.join();
@@ -17480,64 +17507,45 @@
 
     @Test
     @IgnoreUpTo(Build.VERSION_CODES.TIRAMISU)
+    public void testRestrictedRequestRemovedDueToCarrierPrivilegesLost() throws Exception {
+        doTestNetworkRequestWithCarrierPrivilegesLost(
+                false /* shouldGrantRestrictedNetworkPermission */,
+                Process.myUid() /* lostPrivilegeUid */,
+                TEST_SUBSCRIPTION_ID /* lostPrivilegeSubId */,
+                true /* expectUnavailable */,
+                true /* expectCapChanged */);
+    }
+
+    @Test
+    @IgnoreUpTo(Build.VERSION_CODES.TIRAMISU)
+    public void testRequestNotRemoved_MismatchSubId() throws Exception {
+        doTestNetworkRequestWithCarrierPrivilegesLost(
+                false /* shouldGrantRestrictedNetworkPermission */,
+                Process.myUid() /* lostPrivilegeUid */,
+                TEST_SUBSCRIPTION_ID + 1 /* lostPrivilegeSubId, deliberately mismatched */,
+                false /* expectUnavailable */,
+                false /* expectCapChanged */);
+    }
+    @Test
+    @IgnoreUpTo(Build.VERSION_CODES.TIRAMISU)
     public void testRequestNotRemoved_MismatchUid() throws Exception {
-        mServiceContext.setPermission(CONNECTIVITY_USE_RESTRICTED_NETWORKS, PERMISSION_DENIED);
-        NetworkCapabilities filter = getRestrictedRequestForWifiWithSubIds().networkCapabilities;
-        final HandlerThread handlerThread = new HandlerThread("testRestrictedFactoryRequests");
-        handlerThread.start();
-
-        final MockNetworkFactory testFactory = new MockNetworkFactory(handlerThread.getLooper(),
-                mServiceContext, "testFactory", filter, mCsHandlerThread);
-        testFactory.register();
-
-        doReturn(true).when(mCarrierPrivilegeAuthenticator)
-                .isCarrierServiceUidForNetworkCapabilities(anyInt(), any());
-        final TestNetworkCallback networkCallback1 = new TestNetworkCallback();
-        final NetworkRequest networkrequest1 = getRestrictedRequestForWifiWithSubIds();
-        mCm.requestNetwork(networkrequest1, networkCallback1);
-        testFactory.expectRequestAdd();
-        testFactory.assertRequestCountEquals(1);
-
-        doReturn(false).when(mCarrierPrivilegeAuthenticator)
-                .isCarrierServiceUidForNetworkCapabilities(eq(Process.myUid()), any());
-        final CarrierPrivilegesLostListener carrierPrivilegesLostListener =
-                mService.getCarrierPrivilegesLostListener();
-        carrierPrivilegesLostListener.onCarrierPrivilegesLost(Process.myUid() + 1);
-        expectNoRequestChanged(testFactory);
-
-        handlerThread.quitSafely();
-        handlerThread.join();
+        doTestNetworkRequestWithCarrierPrivilegesLost(
+                false /* shouldGrantRestrictedNetworkPermission */,
+                Process.myUid() + 1 /* lostPrivilegeUid, deliberately mismatched */,
+                TEST_SUBSCRIPTION_ID /* lostPrivilegeSubId */,
+                false /* expectUnavailable */,
+                false /* expectCapChanged */);
     }
 
     @Test
     @IgnoreUpTo(Build.VERSION_CODES.TIRAMISU)
     public void testRequestNotRemoved_HasRestrictedNetworkPermission() throws Exception {
-        mServiceContext.setPermission(CONNECTIVITY_USE_RESTRICTED_NETWORKS, PERMISSION_GRANTED);
-        NetworkCapabilities filter = getRestrictedRequestForWifiWithSubIds().networkCapabilities;
-        final HandlerThread handlerThread = new HandlerThread("testRestrictedFactoryRequests");
-        handlerThread.start();
-
-        final MockNetworkFactory testFactory = new MockNetworkFactory(handlerThread.getLooper(),
-                mServiceContext, "testFactory", filter, mCsHandlerThread);
-        testFactory.register();
-
-        doReturn(true).when(mCarrierPrivilegeAuthenticator)
-            .isCarrierServiceUidForNetworkCapabilities(anyInt(), any());
-        final TestNetworkCallback networkCallback1 = new TestNetworkCallback();
-        final NetworkRequest networkrequest1 = getRestrictedRequestForWifiWithSubIds();
-        mCm.requestNetwork(networkrequest1, networkCallback1);
-        testFactory.expectRequestAdd();
-        testFactory.assertRequestCountEquals(1);
-
-        doReturn(false).when(mCarrierPrivilegeAuthenticator)
-                .isCarrierServiceUidForNetworkCapabilities(eq(Process.myUid()), any());
-        final CarrierPrivilegesLostListener carrierPrivilegesLostListener =
-                mService.getCarrierPrivilegesLostListener();
-        carrierPrivilegesLostListener.onCarrierPrivilegesLost(Process.myUid());
-        expectNoRequestChanged(testFactory);
-
-        handlerThread.quitSafely();
-        handlerThread.join();
+        doTestNetworkRequestWithCarrierPrivilegesLost(
+                true /* shouldGrantRestrictedNetworkPermission */,
+                Process.myUid() /* lostPrivilegeUid */,
+                TEST_SUBSCRIPTION_ID /* lostPrivilegeSubId */,
+                false /* expectUnavailable */,
+                true /* expectCapChanged */);
     }
     @Test
     public void testAllowedUids() throws Exception {
diff --git a/tests/unit/java/com/android/server/connectivity/CarrierPrivilegeAuthenticatorTest.java b/tests/unit/java/com/android/server/connectivity/CarrierPrivilegeAuthenticatorTest.java
index 9f0ec30..7bd2b56 100644
--- a/tests/unit/java/com/android/server/connectivity/CarrierPrivilegeAuthenticatorTest.java
+++ b/tests/unit/java/com/android/server/connectivity/CarrierPrivilegeAuthenticatorTest.java
@@ -20,7 +20,6 @@
 import static android.net.NetworkCapabilities.TRANSPORT_WIFI;
 import static android.telephony.TelephonyManager.ACTION_MULTI_SIM_CONFIG_CHANGED;
 
-import static com.android.server.connectivity.CarrierPrivilegeAuthenticator.CarrierPrivilegesLostListener;
 import static com.android.server.connectivity.ConnectivityFlags.CARRIER_SERVICE_CHANGED_USE_CALLBACK;
 
 import static org.junit.Assert.assertEquals;
@@ -47,7 +46,6 @@
 import android.net.TelephonyNetworkSpecifier;
 import android.os.Build;
 import android.os.HandlerThread;
-import android.telephony.SubscriptionManager;
 import android.telephony.TelephonyManager;
 
 import com.android.net.module.util.CollectionUtils;
@@ -71,6 +69,7 @@
 import java.util.Collections;
 import java.util.Map;
 import java.util.Set;
+import java.util.function.BiConsumer;
 
 /**
  * Tests for CarrierPrivilegeAuthenticatorTest.
@@ -92,7 +91,7 @@
     @NonNull private final TelephonyManagerShimImpl mTelephonyManagerShim;
     @NonNull private final PackageManager mPackageManager;
     @NonNull private TestCarrierPrivilegeAuthenticator mCarrierPrivilegeAuthenticator;
-    @NonNull private final CarrierPrivilegesLostListener mListener;
+    @NonNull private final BiConsumer<Integer, Integer> mListener;
     private final int mCarrierConfigPkgUid = 12345;
     private final boolean mUseCallbacks;
     private final String mTestPkg = "com.android.server.connectivity.test";
@@ -107,9 +106,8 @@
                     mListener);
         }
         @Override
-        protected int getSlotIndex(int subId) {
-            if (SubscriptionManager.DEFAULT_SUBSCRIPTION_ID == subId) return TEST_SUBSCRIPTION_ID;
-            return subId;
+        protected int getSubId(int slotIndex) {
+            return TEST_SUBSCRIPTION_ID; // Test stub: same subId whatever the slot index.
         }
     }
 
@@ -129,7 +127,7 @@
         mTelephonyManager = mock(TelephonyManager.class);
         mTelephonyManagerShim = mock(TelephonyManagerShimImpl.class);
         mPackageManager = mock(PackageManager.class);
-        mListener = mock(CarrierPrivilegesLostListener.class);
+        mListener = mock(BiConsumer.class);
         mHandlerThread = new HandlerThread(CarrierPrivilegeAuthenticatorTest.class.getSimpleName());
         mUseCallbacks = useCallbacks;
         final Dependencies deps = mock(Dependencies.class);
@@ -184,7 +182,7 @@
 
         final NetworkCapabilities.Builder ncBuilder = new NetworkCapabilities.Builder()
                 .addTransportType(TRANSPORT_CELLULAR)
-                .setNetworkSpecifier(new TelephonyNetworkSpecifier(0));
+                .setNetworkSpecifier(new TelephonyNetworkSpecifier(TEST_SUBSCRIPTION_ID));
 
         assertTrue(mCarrierPrivilegeAuthenticator.isCarrierServiceUidForNetworkCapabilities(
                 mCarrierConfigPkgUid, ncBuilder.build()));
@@ -220,7 +218,8 @@
 
         newListeners.get(0).onCarrierServiceChanged(null, mCarrierConfigPkgUid);
 
-        final TelephonyNetworkSpecifier specifier = new TelephonyNetworkSpecifier(0);
+        final TelephonyNetworkSpecifier specifier =
+                new TelephonyNetworkSpecifier(TEST_SUBSCRIPTION_ID);
         final NetworkCapabilities nc = new NetworkCapabilities.Builder()
                 .addTransportType(TRANSPORT_CELLULAR)
                 .setNetworkSpecifier(specifier)
@@ -239,7 +238,11 @@
         l.onCarrierServiceChanged(null, mCarrierConfigPkgUid);
         l.onCarrierServiceChanged(null, mCarrierConfigPkgUid + 1);
         if (mUseCallbacks) {
-            verify(mListener).onCarrierPrivilegesLost(eq(mCarrierConfigPkgUid));
+            verify(mListener).accept(eq(mCarrierConfigPkgUid), eq(TEST_SUBSCRIPTION_ID));
+        }
+        l.onCarrierServiceChanged(null, mCarrierConfigPkgUid + 2);
+        if (mUseCallbacks) {
+            verify(mListener).accept(eq(mCarrierConfigPkgUid + 1), eq(TEST_SUBSCRIPTION_ID));
         }
     }
 
@@ -247,7 +250,8 @@
     public void testOnCarrierPrivilegesChanged() throws Exception {
         final CarrierPrivilegesListenerShim listener = getCarrierPrivilegesListeners().get(0);
 
-        final TelephonyNetworkSpecifier specifier = new TelephonyNetworkSpecifier(0);
+        final TelephonyNetworkSpecifier specifier =
+                new TelephonyNetworkSpecifier(TEST_SUBSCRIPTION_ID);
         final NetworkCapabilities nc = new NetworkCapabilities.Builder()
                 .addTransportType(TRANSPORT_CELLULAR)
                 .setNetworkSpecifier(specifier)
@@ -275,7 +279,7 @@
         assertFalse(mCarrierPrivilegeAuthenticator.isCarrierServiceUidForNetworkCapabilities(
                 mCarrierConfigPkgUid, ncBuilder.build()));
 
-        ncBuilder.setNetworkSpecifier(new TelephonyNetworkSpecifier(0));
+        ncBuilder.setNetworkSpecifier(new TelephonyNetworkSpecifier(TEST_SUBSCRIPTION_ID));
         assertTrue(mCarrierPrivilegeAuthenticator.isCarrierServiceUidForNetworkCapabilities(
                 mCarrierConfigPkgUid, ncBuilder.build()));
 
@@ -284,7 +288,7 @@
         ncBuilder.setNetworkSpecifier(null);
         ncBuilder.removeTransportType(TRANSPORT_CELLULAR);
         ncBuilder.addTransportType(TRANSPORT_WIFI);
-        ncBuilder.setNetworkSpecifier(new TelephonyNetworkSpecifier(0));
+        ncBuilder.setNetworkSpecifier(new TelephonyNetworkSpecifier(TEST_SUBSCRIPTION_ID));
         assertFalse(mCarrierPrivilegeAuthenticator.isCarrierServiceUidForNetworkCapabilities(
                 mCarrierConfigPkgUid, ncBuilder.build()));
     }
@@ -298,7 +302,7 @@
         final NetworkCapabilities.Builder ncBuilder = new NetworkCapabilities.Builder();
         ncBuilder.addTransportType(TRANSPORT_WIFI);
         ncBuilder.removeCapability(NetworkCapabilities.NET_CAPABILITY_NOT_RESTRICTED);
-        ncBuilder.setSubscriptionIds(Set.of(0));
+        ncBuilder.setSubscriptionIds(Set.of(TEST_SUBSCRIPTION_ID));
         assertTrue(mCarrierPrivilegeAuthenticator.isCarrierServiceUidForNetworkCapabilities(
                 mCarrierConfigPkgUid, ncBuilder.build()));
     }
diff --git a/tests/unit/java/com/android/server/connectivityservice/base/CSTest.kt b/tests/unit/java/com/android/server/connectivityservice/base/CSTest.kt
index b15c684..e401434 100644
--- a/tests/unit/java/com/android/server/connectivityservice/base/CSTest.kt
+++ b/tests/unit/java/com/android/server/connectivityservice/base/CSTest.kt
@@ -60,7 +60,6 @@
 import com.android.networkstack.apishim.common.UnsupportedApiLevelException
 import com.android.server.connectivity.AutomaticOnOffKeepaliveTracker
 import com.android.server.connectivity.CarrierPrivilegeAuthenticator
-import com.android.server.connectivity.CarrierPrivilegeAuthenticator.CarrierPrivilegesLostListener
 import com.android.server.connectivity.ClatCoordinator
 import com.android.server.connectivity.ConnectivityFlags
 import com.android.server.connectivity.MulticastRoutingCoordinatorService
@@ -73,6 +72,7 @@
 import com.android.testutils.waitForIdle
 import java.util.concurrent.Executors
 import java.util.function.Consumer
+import java.util.function.BiConsumer
 import kotlin.test.assertNull
 import kotlin.test.fail
 import org.junit.After
@@ -222,7 +222,7 @@
                 context: Context,
                 tm: TelephonyManager,
                 requestRestrictedWifiEnabled: Boolean,
-                listener: CarrierPrivilegesLostListener
+                listener: BiConsumer<Int, Int>
         ) = if (SdkLevel.isAtLeastT()) mock<CarrierPrivilegeAuthenticator>() else null
 
         var satelliteNetworkFallbackUidUpdate: Consumer<Set<Int>>? = null