Merge "Add connectivity pre-check for net permission cts tests"
diff --git a/framework/src/android/net/NattKeepalivePacketData.java b/framework/src/android/net/NattKeepalivePacketData.java
index a18e713..a524859 100644
--- a/framework/src/android/net/NattKeepalivePacketData.java
+++ b/framework/src/android/net/NattKeepalivePacketData.java
@@ -29,7 +29,9 @@
 import com.android.net.module.util.IpUtils;
 
 import java.net.Inet4Address;
+import java.net.Inet6Address;
 import java.net.InetAddress;
+import java.net.UnknownHostException;
 import java.nio.ByteBuffer;
 import java.nio.ByteOrder;
 import java.util.Objects;
@@ -38,6 +40,7 @@
 @SystemApi
 public final class NattKeepalivePacketData extends KeepalivePacketData implements Parcelable {
     private static final int IPV4_HEADER_LENGTH = 20;
+    private static final int IPV6_HEADER_LENGTH = 40;
     private static final int UDP_HEADER_LENGTH = 8;
 
     // This should only be constructed via static factory methods, such as
@@ -59,13 +62,25 @@
             throw new InvalidPacketException(ERROR_INVALID_PORT);
         }
 
-        if (!(srcAddress instanceof Inet4Address) || !(dstAddress instanceof Inet4Address)) {
+        // Convert IPv4 mapped v6 address to v4 if any.
+        final InetAddress srcAddr, dstAddr;
+        try {
+            srcAddr = InetAddress.getByAddress(srcAddress.getAddress());
+            dstAddr = InetAddress.getByAddress(dstAddress.getAddress());
+        } catch (UnknownHostException e) {
             throw new InvalidPacketException(ERROR_INVALID_IP_ADDRESS);
         }
 
-        return nattKeepalivePacketv4(
-                (Inet4Address) srcAddress, srcPort,
-                (Inet4Address) dstAddress, dstPort);
+        if (srcAddr instanceof Inet4Address && dstAddr instanceof Inet4Address) {
+            return nattKeepalivePacketv4(
+                    (Inet4Address) srcAddr, srcPort, (Inet4Address) dstAddr, dstPort);
+        } else if (srcAddr instanceof Inet6Address && dstAddr instanceof Inet6Address) {
+            return nattKeepalivePacketv6(
+                    (Inet6Address) srcAddr, srcPort, (Inet6Address) dstAddr, dstPort);
+        } else {
+            // Source and destination addresses must be in the same IP family after conversion.
+            throw new InvalidPacketException(ERROR_INVALID_IP_ADDRESS);
+        }
     }
 
     private static NattKeepalivePacketData nattKeepalivePacketv4(
@@ -82,14 +97,14 @@
         // /proc/sys/net/ipv4/ip_default_ttl. Use hard-coded 64 for simplicity.
         buf.put((byte) 64);                                 // TTL
         buf.put((byte) OsConstants.IPPROTO_UDP);
-        int ipChecksumOffset = buf.position();
+        final int ipChecksumOffset = buf.position();
         buf.putShort((short) 0);                            // IP checksum
         buf.put(srcAddress.getAddress());
         buf.put(dstAddress.getAddress());
         buf.putShort((short) srcPort);
         buf.putShort((short) dstPort);
         buf.putShort((short) (UDP_HEADER_LENGTH + 1));      // UDP length
-        int udpChecksumOffset = buf.position();
+        final int udpChecksumOffset = buf.position();
         buf.putShort((short) 0);                            // UDP checksum
         buf.put((byte) 0xff);                               // NAT-T keepalive
         buf.putShort(ipChecksumOffset, IpUtils.ipChecksum(buf, 0));
@@ -98,6 +113,30 @@
         return new NattKeepalivePacketData(srcAddress, srcPort, dstAddress, dstPort, buf.array());
     }
 
+    private static NattKeepalivePacketData nattKeepalivePacketv6(
+            Inet6Address srcAddress, int srcPort, Inet6Address dstAddress, int dstPort)
+            throws InvalidPacketException {
+        final ByteBuffer buf = ByteBuffer.allocate(IPV6_HEADER_LENGTH + UDP_HEADER_LENGTH + 1);
+        buf.order(ByteOrder.BIG_ENDIAN);
+        buf.putInt(0x60000000);                         // IP version, traffic class and flow label
+        buf.putShort((short) (UDP_HEADER_LENGTH + 1));  // Payload length
+        buf.put((byte) OsConstants.IPPROTO_UDP);        // Next header
+        // For native ipv6, this hop limit value should use the per interface v6 hoplimit sysctl.
+        // For 464xlat, this value should use the v4 ttl sysctl.
+        // Either way, for simplicity, just hard code 64.
+        buf.put((byte) 64);                             // Hop limit
+        buf.put(srcAddress.getAddress());               // Source address
+        buf.put(dstAddress.getAddress());               // Destination address
+        // UDP header (source port, destination port, length, checksum), then the payload byte.
+        buf.putShort((short) srcPort);
+        buf.putShort((short) dstPort);
+        buf.putShort((short) (UDP_HEADER_LENGTH + 1));  // UDP length = Payload length
+        final int udpChecksumOffset = buf.position();
+        buf.putShort((short) 0);                        // UDP checksum placeholder, set below
+        buf.put((byte) 0xff);                           // NAT-T keepalive. 1 byte of data
+        buf.putShort(udpChecksumOffset, IpUtils.udpChecksum(buf, 0, IPV6_HEADER_LENGTH));
+        return new NattKeepalivePacketData(srcAddress, srcPort, dstAddress, dstPort, buf.array());
+    }
     /** Parcelable Implementation */
     public int describeContents() {
         return 0;
diff --git a/service/src/com/android/server/connectivity/AutomaticOnOffKeepaliveTracker.java b/service/src/com/android/server/connectivity/AutomaticOnOffKeepaliveTracker.java
index 981bd4d..6ba2033 100644
--- a/service/src/com/android/server/connectivity/AutomaticOnOffKeepaliveTracker.java
+++ b/service/src/com/android/server/connectivity/AutomaticOnOffKeepaliveTracker.java
@@ -94,6 +94,7 @@
     private static final int ADJUST_TCP_POLLING_DELAY_MS = 2000;
     private static final String AUTOMATIC_ON_OFF_KEEPALIVE_VERSION =
             "automatic_on_off_keepalive_version";
+    public static final long METRICS_COLLECTION_DURATION_MS = 24 * 60 * 60 * 1_000L;
 
     // ConnectivityService parses message constants from itself and AutomaticOnOffKeepaliveTracker
     // with MessageUtils for debugging purposes, and crashes if some messages have the same values.
@@ -180,6 +181,9 @@
     private final LocalLog mEventLog = new LocalLog(MAX_EVENTS_LOGS);
 
     private final KeepaliveStatsTracker mKeepaliveStatsTracker;
+
+    private final long mMetricsWriteTimeBase;
+
     /**
      * Information about a managed keepalive.
      *
@@ -311,7 +315,26 @@
                 mContext, mConnectivityServiceHandler);
 
         mAlarmManager = mDependencies.getAlarmManager(context);
-        mKeepaliveStatsTracker = new KeepaliveStatsTracker(context, handler);
+        mKeepaliveStatsTracker =
+                mDependencies.newKeepaliveStatsTracker(context, handler);
+
+        final long time = mDependencies.getElapsedRealtime();
+        mMetricsWriteTimeBase = time % METRICS_COLLECTION_DURATION_MS;
+        final long triggerAtMillis = time + METRICS_COLLECTION_DURATION_MS; // always in the future
+        mAlarmManager.set(AlarmManager.ELAPSED_REALTIME_WAKEUP, triggerAtMillis, TAG,
+                this::writeMetricsAndRescheduleAlarm, handler);
+    }
+
+    private void writeMetricsAndRescheduleAlarm() {
+        mKeepaliveStatsTracker.writeAndResetMetrics();
+
+        final long time = mDependencies.getElapsedRealtime();
+        final long triggerAtMillis =
+                mMetricsWriteTimeBase // offset of the write time within a collection period
+                        + (time - time % METRICS_COLLECTION_DURATION_MS) // current period start
+                        + METRICS_COLLECTION_DURATION_MS; // advance to the next period
+        mAlarmManager.set(AlarmManager.ELAPSED_REALTIME_WAKEUP, triggerAtMillis, TAG,
+                this::writeMetricsAndRescheduleAlarm, mConnectivityServiceHandler);
     }
 
     private void startTcpPollingAlarm(@NonNull AutomaticOnOffKeepalive ki) {
@@ -898,6 +921,14 @@
         }
 
         /**
+         * Construct a new KeepaliveStatsTracker.
+         */
+        public KeepaliveStatsTracker newKeepaliveStatsTracker(@NonNull Context context,
+                @NonNull Handler connectivityServiceHandler) {
+            return new KeepaliveStatsTracker(context, connectivityServiceHandler);
+        }
+
+        /**
          * Find out if a feature is enabled from DeviceConfig.
          *
          * @param name The name of the property to look up.
diff --git a/service/src/com/android/server/connectivity/KeepaliveStatsTracker.java b/service/src/com/android/server/connectivity/KeepaliveStatsTracker.java
index ff9bb70..d59d526 100644
--- a/service/src/com/android/server/connectivity/KeepaliveStatsTracker.java
+++ b/service/src/com/android/server/connectivity/KeepaliveStatsTracker.java
@@ -19,7 +19,10 @@
 import static android.telephony.SubscriptionManager.OnSubscriptionsChangedListener;
 
 import android.annotation.NonNull;
+import android.content.BroadcastReceiver;
 import android.content.Context;
+import android.content.Intent;
+import android.content.IntentFilter;
 import android.net.Network;
 import android.net.NetworkCapabilities;
 import android.net.NetworkSpecifier;
@@ -42,6 +45,8 @@
 import com.android.metrics.KeepaliveLifetimeForCarrier;
 import com.android.metrics.KeepaliveLifetimePerCarrier;
 import com.android.modules.utils.BackgroundThread;
+import com.android.net.module.util.CollectionUtils;
+import com.android.server.ConnectivityStatsLog;
 
 import java.util.ArrayList;
 import java.util.HashMap;
@@ -67,9 +72,7 @@
     // Mapping of subId to carrierId. Updates are received from OnSubscriptionsChangedListener
     private final SparseIntArray mCachedCarrierIdPerSubId = new SparseIntArray();
     // The default subscription id obtained from SubscriptionManager.getDefaultSubscriptionId.
-    // Updates are done from the OnSubscriptionsChangedListener. Note that there is no callback done
-    // to OnSubscriptionsChangedListener when the default sub id changes.
-    // TODO: Register a listener for the default subId when it is possible.
+    // Updates are received from the ACTION_DEFAULT_SUBSCRIPTION_CHANGED broadcast.
     private int mCachedDefaultSubscriptionId = SubscriptionManager.INVALID_SUBSCRIPTION_ID;
 
     // Class to store network information, lifetime durations and active state of a keepalive.
@@ -267,6 +270,19 @@
                 Objects.requireNonNull(context.getSystemService(SubscriptionManager.class));
 
         mLastUpdateDurationsTimestamp = mDependencies.getElapsedRealtime();
+        context.registerReceiver(
+                new BroadcastReceiver() {
+                    @Override
+                    public void onReceive(Context context, Intent intent) {
+                        mCachedDefaultSubscriptionId =
+                                intent.getIntExtra(
+                                        SubscriptionManager.EXTRA_SUBSCRIPTION_INDEX,
+                                        SubscriptionManager.INVALID_SUBSCRIPTION_ID);
+                    }
+                },
+                new IntentFilter(SubscriptionManager.ACTION_DEFAULT_SUBSCRIPTION_CHANGED),
+                /* broadcastPermission= */ null,
+                mConnectivityServiceHandler);
 
         // The default constructor for OnSubscriptionsChangedListener will always implicitly grab
         // the looper of the current thread. In the case the current thread does not have a looper,
@@ -284,11 +300,8 @@
                                 // but not necessarily empty, simply ignore it. Another call to the
                                 // listener will be invoked in the future.
                                 if (activeSubInfoList == null) return;
-                                final int defaultSubId =
-                                        subscriptionManager.getDefaultSubscriptionId();
                                 mConnectivityServiceHandler.post(() -> {
                                     mCachedCarrierIdPerSubId.clear();
-                                    mCachedDefaultSubscriptionId = defaultSubId;
 
                                     for (final SubscriptionInfo subInfo : activeSubInfoList) {
                                         mCachedCarrierIdPerSubId.put(subInfo.getSubscriptionId(),
@@ -592,6 +605,7 @@
      *
      * @return the DailykeepaliveInfoReported proto that was built.
      */
+    @VisibleForTesting
     public @NonNull DailykeepaliveInfoReported buildAndResetMetrics() {
         ensureRunningOnHandlerThread();
         final long timeNow = mDependencies.getElapsedRealtime();
@@ -620,6 +634,20 @@
         return metrics;
     }
 
+    /** Builds and resets the stored metrics, then writes the result to ConnectivityStatsLog. */
+    public void writeAndResetMetrics() {
+        ensureRunningOnHandlerThread();
+        final DailykeepaliveInfoReported dailyKeepaliveInfoReported = buildAndResetMetrics();
+        ConnectivityStatsLog.write(
+                ConnectivityStatsLog.DAILY_KEEPALIVE_INFO_REPORTED,
+                dailyKeepaliveInfoReported.getDurationPerNumOfKeepalive().toByteArray(),
+                dailyKeepaliveInfoReported.getKeepaliveLifetimePerCarrier().toByteArray(),
+                dailyKeepaliveInfoReported.getKeepaliveRequests(),
+                dailyKeepaliveInfoReported.getAutomaticKeepaliveRequests(),
+                dailyKeepaliveInfoReported.getDistinctUserCount(),
+                CollectionUtils.toIntArray(dailyKeepaliveInfoReported.getUidList()));
+    }
+
     private void ensureRunningOnHandlerThread() {
         if (mConnectivityServiceHandler.getLooper().getThread() != Thread.currentThread()) {
             throw new IllegalStateException(
diff --git a/tests/common/java/android/net/NattKeepalivePacketDataTest.kt b/tests/common/java/android/net/NattKeepalivePacketDataTest.kt
index 36b9a91..1f04fb8 100644
--- a/tests/common/java/android/net/NattKeepalivePacketDataTest.kt
+++ b/tests/common/java/android/net/NattKeepalivePacketDataTest.kt
@@ -46,6 +46,8 @@
     private val TEST_SRC_ADDRV4 = "198.168.0.2".address()
     private val TEST_DST_ADDRV4 = "198.168.0.1".address()
     private val TEST_ADDRV6 = "2001:db8::1".address()
+    private val TEST_ADDRV4MAPPEDV6 = "::ffff:1.2.3.4".address()
+    private val TEST_ADDRV4 = "1.2.3.4".address()
 
     private fun String.address() = InetAddresses.parseNumericAddress(this)
     private fun nattKeepalivePacket(
@@ -83,6 +85,28 @@
         }
     }
 
+    @Test @IgnoreUpTo(Build.VERSION_CODES.R)
+    fun testConstructor_afterR() {
+        // v4 mapped v6 will be translated to a v4 address.
+        assertFailsWith<InvalidPacketException> {
+            nattKeepalivePacket(srcAddress = TEST_ADDRV6, dstAddress = TEST_ADDRV4MAPPEDV6)
+        }
+        assertFailsWith<InvalidPacketException> {
+            nattKeepalivePacket(srcAddress = TEST_ADDRV4MAPPEDV6, dstAddress = TEST_ADDRV6)
+        }
+
+        // Both src and dst address will be v4 after translation, so it won't cause exception.
+        val packet1 = nattKeepalivePacket(
+            dstAddress = TEST_ADDRV4MAPPEDV6, srcAddress = TEST_ADDRV4MAPPEDV6)
+        assertEquals(TEST_ADDRV4, packet1.srcAddress)
+        assertEquals(TEST_ADDRV4, packet1.dstAddress)
+
+        // Packet with v6 src and v6 dst address is valid.
+        val packet2 = nattKeepalivePacket(srcAddress = TEST_ADDRV6, dstAddress = TEST_ADDRV6)
+        assertEquals(TEST_ADDRV6, packet2.srcAddress)
+        assertEquals(TEST_ADDRV6, packet2.dstAddress)
+    }
+
     @Test @IgnoreUpTo(Build.VERSION_CODES.Q)
     fun testParcel() {
         assertParcelingIsLossless(nattKeepalivePacket())
diff --git a/tests/cts/net/util/java/android/net/cts/util/CtsNetUtils.java b/tests/cts/net/util/java/android/net/cts/util/CtsNetUtils.java
index ce789fc..21f1358 100644
--- a/tests/cts/net/util/java/android/net/cts/util/CtsNetUtils.java
+++ b/tests/cts/net/util/java/android/net/cts/util/CtsNetUtils.java
@@ -86,7 +86,7 @@
     private static final String FEATURE_IPSEC_TUNNEL_MIGRATION =
             "android.software.ipsec_tunnel_migration";
 
-    private static final int SOCKET_TIMEOUT_MS = 2000;
+    private static final int SOCKET_TIMEOUT_MS = 10_000;
     private static final int PRIVATE_DNS_PROBE_MS = 1_000;
 
     private static final int PRIVATE_DNS_SETTING_TIMEOUT_MS = 10_000;
diff --git a/tests/unit/java/com/android/server/ConnectivityServiceTest.java b/tests/unit/java/com/android/server/ConnectivityServiceTest.java
index 57c3acc..8f3e097 100755
--- a/tests/unit/java/com/android/server/ConnectivityServiceTest.java
+++ b/tests/unit/java/com/android/server/ConnectivityServiceTest.java
@@ -6856,10 +6856,6 @@
         ka = mCm.startNattKeepalive(myNet, validKaInterval, callback, myIPv6, 1234, dstIPv4);
         callback.expectError(PacketKeepalive.ERROR_INVALID_IP_ADDRESS);
 
-        // NAT-T is only supported for IPv4.
-        ka = mCm.startNattKeepalive(myNet, validKaInterval, callback, myIPv6, 1234, dstIPv6);
-        callback.expectError(PacketKeepalive.ERROR_INVALID_IP_ADDRESS);
-
         ka = mCm.startNattKeepalive(myNet, validKaInterval, callback, myIPv4, 123456, dstIPv4);
         callback.expectError(PacketKeepalive.ERROR_INVALID_PORT);
 
@@ -7010,13 +7006,6 @@
             callback.expectError(SocketKeepalive.ERROR_INVALID_IP_ADDRESS);
         }
 
-        // NAT-T is only supported for IPv4.
-        try (SocketKeepalive ka = mCm.createSocketKeepalive(
-                myNet, testSocket, myIPv6, dstIPv6, executor, callback)) {
-            ka.start(validKaInterval);
-            callback.expectError(SocketKeepalive.ERROR_INVALID_IP_ADDRESS);
-        }
-
         // Basic check before testing started keepalive.
         try (SocketKeepalive ka = mCm.createSocketKeepalive(
                 myNet, testSocket, myIPv4, dstIPv4, executor, callback)) {
diff --git a/tests/unit/java/com/android/server/VpnManagerServiceTest.java b/tests/unit/java/com/android/server/VpnManagerServiceTest.java
index deb56ef..bf23cd1 100644
--- a/tests/unit/java/com/android/server/VpnManagerServiceTest.java
+++ b/tests/unit/java/com/android/server/VpnManagerServiceTest.java
@@ -75,12 +75,15 @@
 @IgnoreUpTo(R) // VpnManagerService is not available before R
 @SmallTest
 public class VpnManagerServiceTest extends VpnTestBase {
+    private static final String CONTEXT_ATTRIBUTION_TAG = "VPN_MANAGER";
+
     @Rule
     public final DevSdkIgnoreRule mIgnoreRule = new DevSdkIgnoreRule();
 
     private static final int TIMEOUT_MS = 2_000;
 
     @Mock Context mContext;
+    @Mock Context mContextWithoutAttributionTag;
     @Mock Context mSystemContext;
     @Mock Context mUserAllContext;
     private HandlerThread mHandlerThread;
@@ -144,6 +147,13 @@
 
         mHandlerThread = new HandlerThread("TestVpnManagerService");
         mDeps = new VpnManagerServiceDependencies();
+
+        // The attribution tag is a dependency for IKE library to collect VPN metrics correctly
+        // and thus should not be changed without updating the IKE code.
+        doReturn(mContext)
+                .when(mContextWithoutAttributionTag)
+                .createAttributionContext(CONTEXT_ATTRIBUTION_TAG);
+
         doReturn(mUserAllContext).when(mContext).createContextAsUser(UserHandle.ALL, 0);
         doReturn(mSystemContext).when(mContext).createContextAsUser(UserHandle.SYSTEM, 0);
         doReturn(mPackageManager).when(mContext).getPackageManager();
@@ -153,7 +163,7 @@
         mockService(mContext, UserManager.class, Context.USER_SERVICE, mUserManager);
         doReturn(SYSTEM_USER).when(mUserManager).getUserInfo(eq(SYSTEM_USER_ID));
 
-        mService = new VpnManagerService(mContext, mDeps);
+        mService = new VpnManagerService(mContextWithoutAttributionTag, mDeps);
         mService.systemReady();
 
         final ArgumentCaptor<BroadcastReceiver> intentReceiverCaptor =
diff --git a/tests/unit/java/com/android/server/connectivity/AutomaticOnOffKeepaliveTrackerTest.java b/tests/unit/java/com/android/server/connectivity/AutomaticOnOffKeepaliveTrackerTest.java
index 0aecd64..8232658 100644
--- a/tests/unit/java/com/android/server/connectivity/AutomaticOnOffKeepaliveTrackerTest.java
+++ b/tests/unit/java/com/android/server/connectivity/AutomaticOnOffKeepaliveTrackerTest.java
@@ -21,6 +21,7 @@
 import static android.net.NetworkAgent.CMD_STOP_SOCKET_KEEPALIVE;
 import static android.net.NetworkCapabilities.TRANSPORT_CELLULAR;
 
+import static com.android.server.connectivity.AutomaticOnOffKeepaliveTracker.METRICS_COLLECTION_DURATION_MS;
 import static com.android.testutils.HandlerUtils.visibleOnHandlerThread;
 
 import static org.junit.Assert.assertEquals;
@@ -42,6 +43,7 @@
 import static org.mockito.Mockito.ignoreStubs;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.never;
+import static org.mockito.Mockito.spy;
 import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.verifyNoMoreInteractions;
 
@@ -125,6 +127,7 @@
     @Mock NetworkAgentInfo mNai;
     @Mock SubscriptionManager mSubscriptionManager;
 
+    KeepaliveStatsTracker mKeepaliveStatsTracker;
     TestKeepaliveTracker mKeepaliveTracker;
     AOOTestHandler mTestHandler;
     TestTcpKeepaliveController mTcpController;
@@ -344,8 +347,14 @@
         mTestHandler = new AOOTestHandler(mHandlerThread.getLooper());
         mTcpController = new TestTcpKeepaliveController(mTestHandler);
         mKeepaliveTracker = new TestKeepaliveTracker(mCtx, mTestHandler, mTcpController);
+        mKeepaliveStatsTracker = spy(new KeepaliveStatsTracker(mCtx, mTestHandler));
         doReturn(mKeepaliveTracker).when(mDependencies).newKeepaliveTracker(mCtx, mTestHandler);
+        doReturn(mKeepaliveStatsTracker)
+                .when(mDependencies)
+                .newKeepaliveStatsTracker(mCtx, mTestHandler);
+
         doReturn(true).when(mDependencies).isFeatureEnabled(any(), anyBoolean());
+        doReturn(0L).when(mDependencies).getElapsedRealtime();
         mAOOKeepaliveTracker =
                 new AutomaticOnOffKeepaliveTracker(mCtx, mTestHandler, mDependencies);
     }
@@ -499,6 +508,30 @@
         assertEquals(testInfo.underpinnedNetwork, mTestHandler.mLastAutoKi.getUnderpinnedNetwork());
     }
 
+    @Test
+    public void testAlarm_writeMetrics() throws Exception {
+        final ArgumentCaptor<AlarmManager.OnAlarmListener> listenerCaptor =
+                ArgumentCaptor.forClass(AlarmManager.OnAlarmListener.class);
+
+        // First AlarmManager.set call from the constructor.
+        verify(mAlarmManager).set(eq(AlarmManager.ELAPSED_REALTIME_WAKEUP),
+                eq(METRICS_COLLECTION_DURATION_MS), any() /* tag */, listenerCaptor.capture(),
+                eq(mTestHandler));
+
+        final AlarmManager.OnAlarmListener listener = listenerCaptor.getValue();
+
+        doReturn(METRICS_COLLECTION_DURATION_MS).when(mDependencies).getElapsedRealtime();
+        // For realism, the listener should be posted on the handler
+        mTestHandler.post(() -> listener.onAlarm());
+        HandlerUtils.waitForIdle(mTestHandler, TIMEOUT_MS);
+
+        verify(mKeepaliveStatsTracker).writeAndResetMetrics();
+        // Alarm is rescheduled.
+        verify(mAlarmManager).set(eq(AlarmManager.ELAPSED_REALTIME_WAKEUP),
+                eq(METRICS_COLLECTION_DURATION_MS * 2),
+                any() /* tag */, listenerCaptor.capture(), eq(mTestHandler));
+    }
+
     private void setupResponseWithSocketExisting() throws Exception {
         final ByteBuffer tcpBufferV6 = getByteBuffer(TEST_RESPONSE_BYTES);
         final ByteBuffer tcpBufferV4 = getByteBuffer(TEST_RESPONSE_BYTES);
diff --git a/tests/unit/java/com/android/server/connectivity/KeepaliveStatsTrackerTest.java b/tests/unit/java/com/android/server/connectivity/KeepaliveStatsTrackerTest.java
index 9472ded..0d2e540 100644
--- a/tests/unit/java/com/android/server/connectivity/KeepaliveStatsTrackerTest.java
+++ b/tests/unit/java/com/android/server/connectivity/KeepaliveStatsTrackerTest.java
@@ -27,12 +27,15 @@
 import static org.junit.Assert.fail;
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.ArgumentMatchers.anyLong;
+import static org.mockito.ArgumentMatchers.eq;
 import static org.mockito.Mockito.doCallRealMethod;
 import static org.mockito.Mockito.doReturn;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.verify;
 
+import android.content.BroadcastReceiver;
 import android.content.Context;
+import android.content.Intent;
 import android.net.Network;
 import android.net.NetworkCapabilities;
 import android.net.TelephonyNetworkSpecifier;
@@ -109,6 +112,18 @@
     @Mock private KeepaliveStatsTracker.Dependencies mDependencies;
     @Mock private SubscriptionManager mSubscriptionManager;
 
+    private void triggerBroadcastDefaultSubId(int subId) {
+        final ArgumentCaptor<BroadcastReceiver> receiverCaptor =
+                ArgumentCaptor.forClass(BroadcastReceiver.class);
+        verify(mContext).registerReceiver(receiverCaptor.capture(), /* filter= */ any(),
+                /* broadcastPermission= */ any(), eq(mTestHandler));
+        final Intent intent =
+                new Intent(SubscriptionManager.ACTION_DEFAULT_SUBSCRIPTION_CHANGED);
+        intent.putExtra(SubscriptionManager.EXTRA_SUBSCRIPTION_INDEX, subId);
+
+        receiverCaptor.getValue().onReceive(mContext, intent);
+    }
+
     private OnSubscriptionsChangedListener getOnSubscriptionsChangedListener() {
         final ArgumentCaptor<OnSubscriptionsChangedListener> listenerCaptor =
                 ArgumentCaptor.forClass(OnSubscriptionsChangedListener.class);
@@ -1128,4 +1143,53 @@
                             writeTime * 3 - startTime1 - startTime2 - startTime3)
                 });
     }
+
+    @Test
+    public void testUpdateDefaultSubId() {
+        final int startTime1 = 1000;
+        final int startTime2 = 3000;
+        final int writeTime = 5000;
+
+        // No TelephonyNetworkSpecifier set with subId to force the use of default subId.
+        final NetworkCapabilities nc =
+                new NetworkCapabilities.Builder().addTransportType(TRANSPORT_CELLULAR).build();
+        onStartKeepalive(startTime1, TEST_SLOT, nc);
+        // Update default subId
+        triggerBroadcastDefaultSubId(TEST_SUB_ID_1);
+        onStartKeepalive(startTime2, TEST_SLOT2, nc);
+
+        final DailykeepaliveInfoReported dailyKeepaliveInfoReported =
+                buildKeepaliveMetrics(writeTime);
+
+        final int[] expectRegisteredDurations =
+                new int[] {startTime1, startTime2 - startTime1, writeTime - startTime2};
+        final int[] expectActiveDurations =
+                new int[] {startTime1, startTime2 - startTime1, writeTime - startTime2};
+        // Expect the carrier id of the first keepalive to be unknown
+        final KeepaliveCarrierStats expectKeepaliveCarrierStats1 =
+                new KeepaliveCarrierStats(
+                        TelephonyManager.UNKNOWN_CARRIER_ID,
+                        /* transportTypes= */ (1 << TRANSPORT_CELLULAR),
+                        TEST_KEEPALIVE_INTERVAL_SEC * 1000,
+                        writeTime - startTime1,
+                        writeTime - startTime1);
+        // Expect the carrier id of the second keepalive to be TEST_CARRIER_ID_1, from TEST_SUB_ID_1
+        final KeepaliveCarrierStats expectKeepaliveCarrierStats2 =
+                new KeepaliveCarrierStats(
+                        TEST_CARRIER_ID_1,
+                        /* transportTypes= */ (1 << TRANSPORT_CELLULAR),
+                        TEST_KEEPALIVE_INTERVAL_SEC * 1000,
+                        writeTime - startTime2,
+                        writeTime - startTime2);
+        assertDailyKeepaliveInfoReported(
+                dailyKeepaliveInfoReported,
+                /* expectRequestsCount= */ 2,
+                /* expectAutoRequestsCount= */ 2,
+                /* expectAppUids= */ new int[] {TEST_UID},
+                expectRegisteredDurations,
+                expectActiveDurations,
+                new KeepaliveCarrierStats[] {
+                    expectKeepaliveCarrierStats1, expectKeepaliveCarrierStats2
+                });
+    }
 }