Merge "bpf_progs: disable BTF on <=U && user builds"
diff --git a/framework/src/android/net/NattKeepalivePacketData.java b/framework/src/android/net/NattKeepalivePacketData.java
index a18e713..9e6d80d 100644
--- a/framework/src/android/net/NattKeepalivePacketData.java
+++ b/framework/src/android/net/NattKeepalivePacketData.java
@@ -29,7 +29,9 @@
 import com.android.net.module.util.IpUtils;
 
 import java.net.Inet4Address;
+import java.net.Inet6Address;
 import java.net.InetAddress;
+import java.net.UnknownHostException;
 import java.nio.ByteBuffer;
 import java.nio.ByteOrder;
 import java.util.Objects;
@@ -38,6 +40,7 @@
 @SystemApi
 public final class NattKeepalivePacketData extends KeepalivePacketData implements Parcelable {
     private static final int IPV4_HEADER_LENGTH = 20;
+    private static final int IPV6_HEADER_LENGTH = 40;
     private static final int UDP_HEADER_LENGTH = 8;
 
     // This should only be constructed via static factory methods, such as
@@ -59,13 +62,25 @@
             throw new InvalidPacketException(ERROR_INVALID_PORT);
         }
 
-        if (!(srcAddress instanceof Inet4Address) || !(dstAddress instanceof Inet4Address)) {
+        // Convert IPv4 mapped v6 address to v4 if any.
+        final InetAddress srcAddr, dstAddr;
+        try {
+            srcAddr = InetAddress.getByAddress(srcAddress.getAddress());
+            dstAddr = InetAddress.getByAddress(dstAddress.getAddress());
+        } catch (UnknownHostException e) {
             throw new InvalidPacketException(ERROR_INVALID_IP_ADDRESS);
         }
 
-        return nattKeepalivePacketv4(
-                (Inet4Address) srcAddress, srcPort,
-                (Inet4Address) dstAddress, dstPort);
+        if (srcAddr instanceof Inet4Address && dstAddr instanceof Inet4Address) {
+            return nattKeepalivePacketv4(
+                    (Inet4Address) srcAddr, srcPort, (Inet4Address) dstAddr, dstPort);
+        } else if (srcAddr instanceof Inet6Address && dstAddr instanceof Inet6Address) {
+            return nattKeepalivePacketv6(
+                    (Inet6Address) srcAddr, srcPort, (Inet6Address) dstAddr, dstPort);
+        } else {
+            // Destination address and source address should be the same IP family.
+            throw new InvalidPacketException(ERROR_INVALID_IP_ADDRESS);
+        }
     }
 
     private static NattKeepalivePacketData nattKeepalivePacketv4(
@@ -82,14 +97,14 @@
         // /proc/sys/net/ipv4/ip_default_ttl. Use hard-coded 64 for simplicity.
         buf.put((byte) 64);                                 // TTL
         buf.put((byte) OsConstants.IPPROTO_UDP);
-        int ipChecksumOffset = buf.position();
+        final int ipChecksumOffset = buf.position();
         buf.putShort((short) 0);                            // IP checksum
         buf.put(srcAddress.getAddress());
         buf.put(dstAddress.getAddress());
         buf.putShort((short) srcPort);
         buf.putShort((short) dstPort);
         buf.putShort((short) (UDP_HEADER_LENGTH + 1));      // UDP length
-        int udpChecksumOffset = buf.position();
+        final int udpChecksumOffset = buf.position();
         buf.putShort((short) 0);                            // UDP checksum
         buf.put((byte) 0xff);                               // NAT-T keepalive
         buf.putShort(ipChecksumOffset, IpUtils.ipChecksum(buf, 0));
@@ -98,6 +113,30 @@
         return new NattKeepalivePacketData(srcAddress, srcPort, dstAddress, dstPort, buf.array());
     }
 
+    private static NattKeepalivePacketData nattKeepalivePacketv6(
+            Inet6Address srcAddress, int srcPort, Inet6Address dstAddress, int dstPort)
+            throws InvalidPacketException {
+        final ByteBuffer buf = ByteBuffer.allocate(IPV6_HEADER_LENGTH + UDP_HEADER_LENGTH + 1);
+        buf.order(ByteOrder.BIG_ENDIAN);
+        buf.putInt(0x60000000);                         // IP version, traffic class and flow label
+        buf.putShort((short) (UDP_HEADER_LENGTH + 1));  // Payload length
+        buf.put((byte) OsConstants.IPPROTO_UDP);        // Next header
+        // For native ipv6, this hop limit value should use the per interface v6 hoplimit sysctl.
+        // For 464xlat, this value should use the v4 ttl sysctl.
+        // Either way, for simplicity, just hard code 64.
+        buf.put((byte) 64);                             // Hop limit
+        buf.put(srcAddress.getAddress());
+        buf.put(dstAddress.getAddress());
+        // UDP
+        buf.putShort((short) srcPort);
+        buf.putShort((short) dstPort);
+        buf.putShort((short) (UDP_HEADER_LENGTH + 1));  // UDP length = Payload length
+        final int udpChecksumOffset = buf.position();
+        buf.putShort((short) 0);                        // UDP checksum
+        buf.put((byte) 0xff);                           // NAT-T keepalive. 1 byte of data
+        buf.putShort(udpChecksumOffset, IpUtils.udpChecksum(buf, 0, IPV6_HEADER_LENGTH));
+        return new NattKeepalivePacketData(srcAddress, srcPort, dstAddress, dstPort, buf.array());
+    }
     /** Parcelable Implementation */
     public int describeContents() {
         return 0;
diff --git a/netd/BpfHandler.cpp b/netd/BpfHandler.cpp
index 3984249..d239277 100644
--- a/netd/BpfHandler.cpp
+++ b/netd/BpfHandler.cpp
@@ -52,25 +52,25 @@
 static Status attachProgramToCgroup(const char* programPath, const unique_fd& cgroupFd,
                                     bpf_attach_type type) {
     unique_fd cgroupProg(retrieveProgram(programPath));
-    if (cgroupProg == -1) {
-        int ret = errno;
-        ALOGE("Failed to get program from %s: %s", programPath, strerror(ret));
-        return statusFromErrno(ret, "cgroup program get failed");
+    if (!cgroupProg.ok()) {
+        const int err = errno;
+        ALOGE("Failed to get program from %s: %s", programPath, strerror(err));
+        return statusFromErrno(err, "cgroup program get failed");
     }
     if (android::bpf::attachProgram(type, cgroupProg, cgroupFd)) {
-        int ret = errno;
-        ALOGE("Program from %s attach failed: %s", programPath, strerror(ret));
-        return statusFromErrno(ret, "program attach failed");
+        const int err = errno;
+        ALOGE("Program from %s attach failed: %s", programPath, strerror(err));
+        return statusFromErrno(err, "program attach failed");
     }
     return netdutils::status::ok;
 }
 
 static Status checkProgramAccessible(const char* programPath) {
     unique_fd prog(retrieveProgram(programPath));
-    if (prog == -1) {
-        int ret = errno;
-        ALOGE("Failed to get program from %s: %s", programPath, strerror(ret));
-        return statusFromErrno(ret, "program retrieve failed");
+    if (!prog.ok()) {
+        const int err = errno;
+        ALOGE("Failed to get program from %s: %s", programPath, strerror(err));
+        return statusFromErrno(err, "program retrieve failed");
     }
     return netdutils::status::ok;
 }
@@ -79,10 +79,10 @@
     if (modules::sdklevel::IsAtLeastU() && !!strcmp(cg2_path, "/sys/fs/cgroup")) abort();
 
     unique_fd cg_fd(open(cg2_path, O_DIRECTORY | O_RDONLY | O_CLOEXEC));
-    if (cg_fd == -1) {
-        const int ret = errno;
-        ALOGE("Failed to open the cgroup directory: %s", strerror(ret));
-        return statusFromErrno(ret, "Open the cgroup directory failed");
+    if (!cg_fd.ok()) {
+        const int err = errno;
+        ALOGE("Failed to open the cgroup directory: %s", strerror(err));
+        return statusFromErrno(err, "Open the cgroup directory failed");
     }
     RETURN_IF_NOT_OK(checkProgramAccessible(XT_BPF_ALLOWLIST_PROG_PATH));
     RETURN_IF_NOT_OK(checkProgramAccessible(XT_BPF_DENYLIST_PROG_PATH));
diff --git a/service-t/src/com/android/server/connectivity/mdns/ExecutorProvider.java b/service-t/src/com/android/server/connectivity/mdns/ExecutorProvider.java
index 72b65e0..0eebc61 100644
--- a/service-t/src/com/android/server/connectivity/mdns/ExecutorProvider.java
+++ b/service-t/src/com/android/server/connectivity/mdns/ExecutorProvider.java
@@ -42,6 +42,9 @@
     /** Shuts down all the created {@link ScheduledExecutorService} instances. */
     public void shutdownAll() {
         for (ScheduledExecutorService executor : serviceTypeClientSchedulerExecutors) {
+            if (executor.isShutdown()) {
+                continue;
+            }
             executor.shutdownNow();
         }
     }
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsAdvertiser.java b/service-t/src/com/android/server/connectivity/mdns/MdnsAdvertiser.java
index 5f27b6a..158d7a3 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsAdvertiser.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsAdvertiser.java
@@ -287,7 +287,7 @@
         }
 
         @Override
-        public void onSocketCreated(@NonNull Network network,
+        public void onSocketCreated(@NonNull SocketKey socketKey,
                 @NonNull MdnsInterfaceSocket socket,
                 @NonNull List<LinkAddress> addresses) {
             MdnsInterfaceAdvertiser advertiser = mAllAdvertisers.get(socket);
@@ -311,14 +311,14 @@
         }
 
         @Override
-        public void onInterfaceDestroyed(@NonNull Network network,
+        public void onInterfaceDestroyed(@NonNull SocketKey socketKey,
                 @NonNull MdnsInterfaceSocket socket) {
             final MdnsInterfaceAdvertiser advertiser = mAdvertisers.get(socket);
             if (advertiser != null) advertiser.destroyNow();
         }
 
         @Override
-        public void onAddressesChanged(@NonNull Network network,
+        public void onAddressesChanged(@NonNull SocketKey socketKey,
                 @NonNull MdnsInterfaceSocket socket, @NonNull List<LinkAddress> addresses) {
             final MdnsInterfaceAdvertiser advertiser = mAdvertisers.get(socket);
             if (advertiser != null) advertiser.updateAddresses(addresses);
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsDiscoveryManager.java b/service-t/src/com/android/server/connectivity/mdns/MdnsDiscoveryManager.java
index 39fceb9..afad3b7 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsDiscoveryManager.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsDiscoveryManager.java
@@ -22,7 +22,6 @@
 import android.annotation.NonNull;
 import android.annotation.Nullable;
 import android.annotation.RequiresPermission;
-import android.net.Network;
 import android.os.Handler;
 import android.os.HandlerThread;
 import android.util.ArrayMap;
@@ -36,7 +35,6 @@
 import java.io.IOException;
 import java.util.ArrayList;
 import java.util.List;
-import java.util.Objects;
 
 /**
  * This class keeps tracking the set of registered {@link MdnsServiceBrowserListener} instances, and
@@ -50,54 +48,58 @@
     private final MdnsSocketClientBase socketClient;
     @NonNull private final SharedLog sharedLog;
 
-    @NonNull private final PerNetworkServiceTypeClients perNetworkServiceTypeClients;
+    @NonNull private final PerSocketServiceTypeClients perSocketServiceTypeClients;
     @NonNull private final Handler handler;
     @Nullable private final HandlerThread handlerThread;
 
-    private static class PerNetworkServiceTypeClients {
-        private final ArrayMap<Pair<String, Network>, MdnsServiceTypeClient> clients =
+    private static class PerSocketServiceTypeClients {
+        private final ArrayMap<Pair<String, SocketKey>, MdnsServiceTypeClient> clients =
                 new ArrayMap<>();
 
-        public void put(@NonNull String serviceType, @Nullable Network network,
+        public void put(@NonNull String serviceType, @NonNull SocketKey socketKey,
                 @NonNull MdnsServiceTypeClient client) {
             final String dnsLowerServiceType = MdnsUtils.toDnsLowerCase(serviceType);
-            final Pair<String, Network> perNetworkServiceType = new Pair<>(dnsLowerServiceType,
-                    network);
-            clients.put(perNetworkServiceType, client);
+            final Pair<String, SocketKey> perSocketServiceType = new Pair<>(dnsLowerServiceType,
+                    socketKey);
+            clients.put(perSocketServiceType, client);
         }
 
         @Nullable
-        public MdnsServiceTypeClient get(@NonNull String serviceType, @Nullable Network network) {
+        public MdnsServiceTypeClient get(
+                @NonNull String serviceType, @NonNull SocketKey socketKey) {
             final String dnsLowerServiceType = MdnsUtils.toDnsLowerCase(serviceType);
-            final Pair<String, Network> perNetworkServiceType = new Pair<>(dnsLowerServiceType,
-                    network);
-            return clients.getOrDefault(perNetworkServiceType, null);
+            final Pair<String, SocketKey> perSocketServiceType = new Pair<>(dnsLowerServiceType,
+                    socketKey);
+            return clients.getOrDefault(perSocketServiceType, null);
         }
 
         public List<MdnsServiceTypeClient> getByServiceType(@NonNull String serviceType) {
             final String dnsLowerServiceType = MdnsUtils.toDnsLowerCase(serviceType);
             final List<MdnsServiceTypeClient> list = new ArrayList<>();
             for (int i = 0; i < clients.size(); i++) {
-                final Pair<String, Network> perNetworkServiceType = clients.keyAt(i);
-                if (dnsLowerServiceType.equals(perNetworkServiceType.first)) {
+                final Pair<String, SocketKey> perSocketServiceType = clients.keyAt(i);
+                if (dnsLowerServiceType.equals(perSocketServiceType.first)) {
                     list.add(clients.valueAt(i));
                 }
             }
             return list;
         }
 
-        public List<MdnsServiceTypeClient> getByNetwork(@Nullable Network network) {
+        public List<MdnsServiceTypeClient> getBySocketKey(@NonNull SocketKey socketKey) {
             final List<MdnsServiceTypeClient> list = new ArrayList<>();
             for (int i = 0; i < clients.size(); i++) {
-                final Pair<String, Network> perNetworkServiceType = clients.keyAt(i);
-                final Network serviceTypeNetwork = perNetworkServiceType.second;
-                if (Objects.equals(network, serviceTypeNetwork)) {
+                final Pair<String, SocketKey> perSocketServiceType = clients.keyAt(i);
+                if (socketKey.equals(perSocketServiceType.second)) {
                     list.add(clients.valueAt(i));
                 }
             }
             return list;
         }
 
+        public List<MdnsServiceTypeClient> getAllMdnsServiceTypeClient() {
+            return new ArrayList<>(clients.values());
+        }
+
         public void remove(@NonNull MdnsServiceTypeClient client) {
             final int index = clients.indexOfValue(client);
             clients.removeAt(index);
@@ -113,7 +115,7 @@
         this.executorProvider = executorProvider;
         this.socketClient = socketClient;
         this.sharedLog = sharedLog;
-        this.perNetworkServiceTypeClients = new PerNetworkServiceTypeClients();
+        this.perSocketServiceTypeClients = new PerSocketServiceTypeClients();
         if (socketClient.getLooper() != null) {
             this.handlerThread = null;
             this.handler = new Handler(socketClient.getLooper());
@@ -164,7 +166,7 @@
             @NonNull String serviceType,
             @NonNull MdnsServiceBrowserListener listener,
             @NonNull MdnsSearchOptions searchOptions) {
-        if (perNetworkServiceTypeClients.isEmpty()) {
+        if (perSocketServiceTypeClients.isEmpty()) {
             // First listener. Starts the socket client.
             try {
                 socketClient.startDiscovery();
@@ -177,29 +179,29 @@
         socketClient.notifyNetworkRequested(listener, searchOptions.getNetwork(),
                 new MdnsSocketClientBase.SocketCreationCallback() {
                     @Override
-                    public void onSocketCreated(@Nullable Network network) {
+                    public void onSocketCreated(@NonNull SocketKey socketKey) {
                         ensureRunningOnHandlerThread(handler);
                         // All listeners of the same service types shares the same
                         // MdnsServiceTypeClient.
                         MdnsServiceTypeClient serviceTypeClient =
-                                perNetworkServiceTypeClients.get(serviceType, network);
+                                perSocketServiceTypeClients.get(serviceType, socketKey);
                         if (serviceTypeClient == null) {
-                            serviceTypeClient = createServiceTypeClient(serviceType, network);
-                            perNetworkServiceTypeClients.put(serviceType, network,
+                            serviceTypeClient = createServiceTypeClient(serviceType, socketKey);
+                            perSocketServiceTypeClients.put(serviceType, socketKey,
                                     serviceTypeClient);
                         }
                         serviceTypeClient.startSendAndReceive(listener, searchOptions);
                     }
 
                     @Override
-                    public void onAllSocketsDestroyed(@Nullable Network network) {
+                    public void onAllSocketsDestroyed(@NonNull SocketKey socketKey) {
                         ensureRunningOnHandlerThread(handler);
                         final MdnsServiceTypeClient serviceTypeClient =
-                                perNetworkServiceTypeClients.get(serviceType, network);
+                                perSocketServiceTypeClients.get(serviceType, socketKey);
                         if (serviceTypeClient == null) return;
                         // Notify all listeners that all services are removed from this socket.
                         serviceTypeClient.notifySocketDestroyed();
-                        perNetworkServiceTypeClients.remove(serviceTypeClient);
+                        perSocketServiceTypeClients.remove(serviceTypeClient);
                     }
                 });
     }
@@ -224,7 +226,7 @@
         socketClient.notifyNetworkUnrequested(listener);
 
         final List<MdnsServiceTypeClient> serviceTypeClients =
-                perNetworkServiceTypeClients.getByServiceType(serviceType);
+                perSocketServiceTypeClients.getByServiceType(serviceType);
         if (serviceTypeClients.isEmpty()) {
             return;
         }
@@ -233,60 +235,59 @@
             if (serviceTypeClient.stopSendAndReceive(listener)) {
                 // No listener is registered for the service type anymore, remove it from the list
                 // of the service type clients.
-                perNetworkServiceTypeClients.remove(serviceTypeClient);
+                perSocketServiceTypeClients.remove(serviceTypeClient);
             }
         }
-        if (perNetworkServiceTypeClients.isEmpty()) {
+        if (perSocketServiceTypeClients.isEmpty()) {
             // No discovery request. Stops the socket client.
+            sharedLog.i("All service type listeners unregistered; stopping discovery");
             socketClient.stopDiscovery();
         }
     }
 
     @Override
-    public void onResponseReceived(@NonNull MdnsPacket packet,
-            int interfaceIndex, @Nullable Network network) {
+    public void onResponseReceived(@NonNull MdnsPacket packet, @NonNull SocketKey socketKey) {
         checkAndRunOnHandlerThread(() ->
-                handleOnResponseReceived(packet, interfaceIndex, network));
+                handleOnResponseReceived(packet, socketKey));
     }
 
-    private void handleOnResponseReceived(@NonNull MdnsPacket packet, int interfaceIndex,
-            @Nullable Network network) {
-        for (MdnsServiceTypeClient serviceTypeClient
-                : getMdnsServiceTypeClient(network)) {
-            serviceTypeClient.processResponse(packet, interfaceIndex, network);
+    private void handleOnResponseReceived(@NonNull MdnsPacket packet,
+            @NonNull SocketKey socketKey) {
+        for (MdnsServiceTypeClient serviceTypeClient : getMdnsServiceTypeClient(socketKey)) {
+            serviceTypeClient.processResponse(
+                    packet, socketKey.getInterfaceIndex(), socketKey.getNetwork());
         }
     }
 
-    private List<MdnsServiceTypeClient> getMdnsServiceTypeClient(@Nullable Network network) {
+    private List<MdnsServiceTypeClient> getMdnsServiceTypeClient(@NonNull SocketKey socketKey) {
         if (socketClient.supportsRequestingSpecificNetworks()) {
-            return perNetworkServiceTypeClients.getByNetwork(network);
+            return perSocketServiceTypeClients.getBySocketKey(socketKey);
         } else {
-            return perNetworkServiceTypeClients.getByNetwork(null);
+            return perSocketServiceTypeClients.getAllMdnsServiceTypeClient();
         }
     }
 
     @Override
     public void onFailedToParseMdnsResponse(int receivedPacketNumber, int errorCode,
-            @Nullable Network network) {
+            @NonNull SocketKey socketKey) {
         checkAndRunOnHandlerThread(() ->
-                handleOnFailedToParseMdnsResponse(receivedPacketNumber, errorCode, network));
+                handleOnFailedToParseMdnsResponse(receivedPacketNumber, errorCode, socketKey));
     }
 
     private void handleOnFailedToParseMdnsResponse(int receivedPacketNumber, int errorCode,
-            @Nullable Network network) {
-        for (MdnsServiceTypeClient serviceTypeClient
-                : getMdnsServiceTypeClient(network)) {
+            @NonNull SocketKey socketKey) {
+        for (MdnsServiceTypeClient serviceTypeClient : getMdnsServiceTypeClient(socketKey)) {
             serviceTypeClient.onFailedToParseMdnsResponse(receivedPacketNumber, errorCode);
         }
     }
 
     @VisibleForTesting
     MdnsServiceTypeClient createServiceTypeClient(@NonNull String serviceType,
-            @Nullable Network network) {
-        sharedLog.log("createServiceTypeClient for type:" + serviceType + ", net:" + network);
+            @NonNull SocketKey socketKey) {
+        sharedLog.log("createServiceTypeClient for type:" + serviceType + " " + socketKey);
         return new MdnsServiceTypeClient(
                 serviceType, socketClient,
-                executorProvider.newServiceTypeClientSchedulerExecutor(), network,
-                sharedLog.forSubComponent(serviceType + "-" + network));
+                executorProvider.newServiceTypeClientSchedulerExecutor(), socketKey,
+                sharedLog.forSubComponent(serviceType + "-" + socketKey));
     }
 }
\ No newline at end of file
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsMultinetworkSocketClient.java b/service-t/src/com/android/server/connectivity/mdns/MdnsMultinetworkSocketClient.java
index 73e4497..d0ca20e 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsMultinetworkSocketClient.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsMultinetworkSocketClient.java
@@ -64,7 +64,7 @@
         @NonNull
         private final SocketCreationCallback mSocketCreationCallback;
         @NonNull
-        private final ArrayMap<MdnsInterfaceSocket, Network> mActiveNetworkSockets =
+        private final ArrayMap<MdnsInterfaceSocket, SocketKey> mActiveNetworkSockets =
                 new ArrayMap<>();
 
         InterfaceSocketCallback(SocketCreationCallback socketCreationCallback) {
@@ -72,32 +72,32 @@
         }
 
         @Override
-        public void onSocketCreated(@Nullable Network network,
+        public void onSocketCreated(@NonNull SocketKey socketKey,
                 @NonNull MdnsInterfaceSocket socket, @NonNull List<LinkAddress> addresses) {
             // The socket may be already created by other request before, try to get the stored
             // ReadPacketHandler.
             ReadPacketHandler handler = mSocketPacketHandlers.get(socket);
             if (handler == null) {
                 // First request to create this socket. Initial a ReadPacketHandler for this socket.
-                handler = new ReadPacketHandler(network, socket.getInterface().getIndex());
+                handler = new ReadPacketHandler(socketKey);
                 mSocketPacketHandlers.put(socket, handler);
             }
             socket.addPacketHandler(handler);
-            mActiveNetworkSockets.put(socket, network);
-            mSocketCreationCallback.onSocketCreated(network);
+            mActiveNetworkSockets.put(socket, socketKey);
+            mSocketCreationCallback.onSocketCreated(socketKey);
         }
 
         @Override
-        public void onInterfaceDestroyed(@Nullable Network network,
+        public void onInterfaceDestroyed(@NonNull SocketKey socketKey,
                 @NonNull MdnsInterfaceSocket socket) {
             notifySocketDestroyed(socket);
             maybeCleanupPacketHandler(socket);
         }
 
         private void notifySocketDestroyed(@NonNull MdnsInterfaceSocket socket) {
-            final Network network = mActiveNetworkSockets.remove(socket);
-            if (!isAnySocketActive(network)) {
-                mSocketCreationCallback.onAllSocketsDestroyed(network);
+            final SocketKey socketKey = mActiveNetworkSockets.remove(socket);
+            if (!isAnySocketActive(socketKey)) {
+                mSocketCreationCallback.onAllSocketsDestroyed(socketKey);
             }
         }
 
@@ -121,18 +121,18 @@
         return false;
     }
 
-    private boolean isAnySocketActive(@Nullable Network network) {
+    private boolean isAnySocketActive(@NonNull SocketKey socketKey) {
         for (int i = 0; i < mRequestedNetworks.size(); i++) {
             final InterfaceSocketCallback isc = mRequestedNetworks.valueAt(i);
-            if (isc.mActiveNetworkSockets.containsValue(network)) {
+            if (isc.mActiveNetworkSockets.containsValue(socketKey)) {
                 return true;
             }
         }
         return false;
     }
 
-    private ArrayMap<MdnsInterfaceSocket, Network> getActiveSockets() {
-        final ArrayMap<MdnsInterfaceSocket, Network> sockets = new ArrayMap<>();
+    private ArrayMap<MdnsInterfaceSocket, SocketKey> getActiveSockets() {
+        final ArrayMap<MdnsInterfaceSocket, SocketKey> sockets = new ArrayMap<>();
         for (int i = 0; i < mRequestedNetworks.size(); i++) {
             final InterfaceSocketCallback isc = mRequestedNetworks.valueAt(i);
             sockets.putAll(isc.mActiveNetworkSockets);
@@ -146,17 +146,15 @@
     }
 
     private class ReadPacketHandler implements MulticastPacketReader.PacketHandler {
-        private final Network mNetwork;
-        private final int mInterfaceIndex;
+        @NonNull private final SocketKey mSocketKey;
 
-        ReadPacketHandler(@NonNull Network network, int interfaceIndex) {
-            mNetwork = network;
-            mInterfaceIndex = interfaceIndex;
+        ReadPacketHandler(@NonNull SocketKey socketKey) {
+            mSocketKey = socketKey;
         }
 
         @Override
         public void handlePacket(byte[] recvbuf, int length, InetSocketAddress src) {
-            processResponsePacket(recvbuf, length, mInterfaceIndex, mNetwork);
+            processResponsePacket(recvbuf, length, mSocketKey);
         }
     }
 
@@ -220,10 +218,10 @@
                 instanceof Inet6Address;
         final boolean isIpv4 = ((InetSocketAddress) packet.getSocketAddress()).getAddress()
                 instanceof Inet4Address;
-        final ArrayMap<MdnsInterfaceSocket, Network> activeSockets = getActiveSockets();
+        final ArrayMap<MdnsInterfaceSocket, SocketKey> activeSockets = getActiveSockets();
         for (int i = 0; i < activeSockets.size(); i++) {
             final MdnsInterfaceSocket socket = activeSockets.keyAt(i);
-            final Network network = activeSockets.valueAt(i);
+            final Network network = activeSockets.valueAt(i).getNetwork();
             // Check ip capability and network before sending packet
             if (((isIpv6 && socket.hasJoinedIpv6()) || (isIpv4 && socket.hasJoinedIpv4()))
                     // Contrary to MdnsUtils.isNetworkMatched, only send packets targeting
@@ -239,8 +237,7 @@
         }
     }
 
-    private void processResponsePacket(byte[] recvbuf, int length, int interfaceIndex,
-            @NonNull Network network) {
+    private void processResponsePacket(byte[] recvbuf, int length, @NonNull SocketKey socketKey) {
         int packetNumber = ++mReceivedPacketNumber;
 
         final MdnsPacket response;
@@ -250,14 +247,14 @@
             if (e.code != MdnsResponseErrorCode.ERROR_NOT_RESPONSE_MESSAGE) {
                 Log.e(TAG, e.getMessage(), e);
                 if (mCallback != null) {
-                    mCallback.onFailedToParseMdnsResponse(packetNumber, e.code, network);
+                    mCallback.onFailedToParseMdnsResponse(packetNumber, e.code, socketKey);
                 }
             }
             return;
         }
 
         if (mCallback != null) {
-            mCallback.onResponseReceived(response, interfaceIndex, network);
+            mCallback.onResponseReceived(response, socketKey);
         }
     }
 
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsServiceTypeClient.java b/service-t/src/com/android/server/connectivity/mdns/MdnsServiceTypeClient.java
index 49a376c..bdc673e 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsServiceTypeClient.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsServiceTypeClient.java
@@ -58,7 +58,7 @@
     private final MdnsSocketClientBase socketClient;
     private final MdnsResponseDecoder responseDecoder;
     private final ScheduledExecutorService executor;
-    @Nullable private final Network network;
+    @NonNull private final SocketKey socketKey;
     @NonNull private final SharedLog sharedLog;
     private final Object lock = new Object();
     private final ArrayMap<MdnsServiceBrowserListener, MdnsSearchOptions> listeners =
@@ -90,9 +90,9 @@
             @NonNull String serviceType,
             @NonNull MdnsSocketClientBase socketClient,
             @NonNull ScheduledExecutorService executor,
-            @Nullable Network network,
+            @NonNull SocketKey socketKey,
             @NonNull SharedLog sharedLog) {
-        this(serviceType, socketClient, executor, new MdnsResponseDecoder.Clock(), network,
+        this(serviceType, socketClient, executor, new MdnsResponseDecoder.Clock(), socketKey,
                 sharedLog);
     }
 
@@ -102,7 +102,7 @@
             @NonNull MdnsSocketClientBase socketClient,
             @NonNull ScheduledExecutorService executor,
             @NonNull MdnsResponseDecoder.Clock clock,
-            @Nullable Network network,
+            @NonNull SocketKey socketKey,
             @NonNull SharedLog sharedLog) {
         this.serviceType = serviceType;
         this.socketClient = socketClient;
@@ -110,7 +110,7 @@
         this.serviceTypeLabels = TextUtils.split(serviceType, "\\.");
         this.responseDecoder = new MdnsResponseDecoder(clock, serviceTypeLabels);
         this.clock = clock;
-        this.network = network;
+        this.socketKey = socketKey;
         this.sharedLog = sharedLog;
     }
 
@@ -199,7 +199,7 @@
                     searchOptions.getSubtypes(),
                     searchOptions.isPassiveMode(),
                     currentSessionId,
-                    network);
+                    socketKey);
             if (hadReply) {
                 requestTaskFuture = scheduleNextRunLocked(taskConfig);
             } else {
@@ -348,6 +348,11 @@
             boolean after = response.isComplete();
             serviceBecomesComplete = !before && after;
         }
+        sharedLog.i(String.format(
+                "Handling response from service: %s, newServiceFound: %b, serviceBecomesComplete:"
+                        + " %b, responseIsComplete: %b",
+                serviceInstanceName, newServiceFound, serviceBecomesComplete,
+                response.isComplete()));
         MdnsServiceInfo serviceInfo =
                 buildMdnsServiceInfoFromResponse(response, serviceTypeLabels);
 
@@ -432,10 +437,10 @@
         private int burstCounter;
         private int timeToRunNextTaskInMs;
         private boolean isFirstBurst;
-        @Nullable private final Network network;
+        @NonNull private final SocketKey socketKey;
 
         QueryTaskConfig(@NonNull Collection<String> subtypes, boolean usePassiveMode,
-                long sessionId, @Nullable Network network) {
+                long sessionId, @NonNull SocketKey socketKey) {
             this.usePassiveMode = usePassiveMode;
             this.subtypes = new ArrayList<>(subtypes);
             this.queriesPerBurst = QUERIES_PER_BURST;
@@ -457,7 +462,7 @@
                 // doubles until it maxes out at TIME_BETWEEN_BURSTS_MS.
                 this.timeBetweenBurstsInMs = INITIAL_TIME_BETWEEN_BURSTS_MS;
             }
-            this.network = network;
+            this.socketKey = socketKey;
         }
 
         QueryTaskConfig getConfigForNextRun() {
@@ -540,7 +545,7 @@
                 // Only the names are used to know which queries to send, other parameters like
                 // interfaceIndex do not matter.
                 servicesToResolve = makeResponsesForResolve(
-                        0 /* interfaceIndex */, config.network);
+                        0 /* interfaceIndex */, config.socketKey.getNetwork());
                 sendDiscoveryQueries = servicesToResolve.size() < listeners.size();
             }
             Pair<Integer, List<String>> result;
@@ -553,7 +558,7 @@
                                 config.subtypes,
                                 config.expectUnicastResponse,
                                 config.transactionId,
-                                config.network,
+                                config.socketKey.getNetwork(),
                                 sendDiscoveryQueries,
                                 servicesToResolve,
                                 clock)
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsSocketClient.java b/service-t/src/com/android/server/connectivity/mdns/MdnsSocketClient.java
index b982644..2b6e5d0 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsSocketClient.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsSocketClient.java
@@ -235,7 +235,7 @@
             throw new IllegalArgumentException("This socket client does not support requesting "
                     + "specific networks");
         }
-        socketCreationCallback.onSocketCreated(null);
+        socketCreationCallback.onSocketCreated(new SocketKey(multicastSocket.getInterfaceIndex()));
     }
 
     @Override
@@ -456,7 +456,8 @@
             LOGGER.w(String.format("Error while decoding %s packet (%d): %d",
                     responseType, packetNumber, e.code));
             if (callback != null) {
-                callback.onFailedToParseMdnsResponse(packetNumber, e.code, network);
+                callback.onFailedToParseMdnsResponse(packetNumber, e.code,
+                        new SocketKey(network, interfaceIndex));
             }
             return e.code;
         }
@@ -466,7 +467,8 @@
         }
 
         if (callback != null) {
-            callback.onResponseReceived(response, interfaceIndex, network);
+            callback.onResponseReceived(
+                    response, new SocketKey(network, interfaceIndex));
         }
 
         return MdnsResponseErrorCode.SUCCESS;
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsSocketClientBase.java b/service-t/src/com/android/server/connectivity/mdns/MdnsSocketClientBase.java
index e0762f9..a35925a 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsSocketClientBase.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsSocketClientBase.java
@@ -73,20 +73,19 @@
     /*** Callback for mdns response  */
     interface Callback {
         /*** Receive a mdns response */
-        void onResponseReceived(@NonNull MdnsPacket packet, int interfaceIndex,
-                @Nullable Network network);
+        void onResponseReceived(@NonNull MdnsPacket packet, @NonNull SocketKey socketKey);
 
         /*** Parse a mdns response failed */
         void onFailedToParseMdnsResponse(int receivedPacketNumber, int errorCode,
-                @Nullable Network network);
+                @NonNull SocketKey socketKey);
     }
 
     /*** Callback for requested socket creation  */
     interface SocketCreationCallback {
         /*** Notify requested socket is created */
-        void onSocketCreated(@Nullable Network network);
+        void onSocketCreated(@NonNull SocketKey socketKey);
 
         /*** Notify requested socket is destroyed */
-        void onAllSocketsDestroyed(@Nullable Network network);
+        void onAllSocketsDestroyed(@NonNull SocketKey socketKey);
     }
 }
diff --git a/service-t/src/com/android/server/connectivity/mdns/MdnsSocketProvider.java b/service-t/src/com/android/server/connectivity/mdns/MdnsSocketProvider.java
index d90f67f..3df6313 100644
--- a/service-t/src/com/android/server/connectivity/mdns/MdnsSocketProvider.java
+++ b/service-t/src/com/android/server/connectivity/mdns/MdnsSocketProvider.java
@@ -258,6 +258,11 @@
                 @NonNull final NetLinkMonitorCallBack cb) {
             return SocketNetLinkMonitorFactory.createNetLinkMonitor(handler, log, cb);
         }
+
+        /*** Get interface index by given socket */
+        public int getInterfaceIndex(@NonNull MdnsInterfaceSocket socket) {
+            return socket.getInterface().getIndex();
+        }
     }
     /**
      * The callback interface for the netlink monitor messages.
@@ -597,8 +602,10 @@
         for (int i = 0; i < mCallbacksToRequestedNetworks.size(); i++) {
             final Network requestedNetwork = mCallbacksToRequestedNetworks.valueAt(i);
             if (isNetworkMatched(requestedNetwork, network)) {
-                mCallbacksToRequestedNetworks.keyAt(i).onSocketCreated(network, socketInfo.mSocket,
-                        socketInfo.mAddresses);
+                final int ifaceIndex = mDependencies.getInterfaceIndex(socketInfo.mSocket);
+                final SocketKey socketKey = new SocketKey(network, ifaceIndex);
+                mCallbacksToRequestedNetworks.keyAt(i).onSocketCreated(socketKey,
+                        socketInfo.mSocket, socketInfo.mAddresses);
                 mSocketRequestMonitor.onSocketRequestFulfilled(network, socketInfo.mSocket,
                         socketInfo.mTransports);
             }
@@ -609,7 +616,9 @@
         for (int i = 0; i < mCallbacksToRequestedNetworks.size(); i++) {
             final Network requestedNetwork = mCallbacksToRequestedNetworks.valueAt(i);
             if (isNetworkMatched(requestedNetwork, network)) {
-                mCallbacksToRequestedNetworks.keyAt(i).onInterfaceDestroyed(network, socket);
+                final int ifaceIndex = mDependencies.getInterfaceIndex(socket);
+                mCallbacksToRequestedNetworks.keyAt(i)
+                        .onInterfaceDestroyed(new SocketKey(network, ifaceIndex), socket);
             }
         }
     }
@@ -619,8 +628,9 @@
         for (int i = 0; i < mCallbacksToRequestedNetworks.size(); i++) {
             final Network requestedNetwork = mCallbacksToRequestedNetworks.valueAt(i);
             if (isNetworkMatched(requestedNetwork, network)) {
+                final int ifaceIndex = mDependencies.getInterfaceIndex(socket);
                 mCallbacksToRequestedNetworks.keyAt(i)
-                        .onAddressesChanged(network, socket, addresses);
+                        .onAddressesChanged(new SocketKey(network, ifaceIndex), socket, addresses);
             }
         }
     }
@@ -637,7 +647,9 @@
             createSocket(new NetworkAsKey(network), lp);
         } else {
             // Notify the socket for requested network.
-            cb.onSocketCreated(network, socketInfo.mSocket, socketInfo.mAddresses);
+            final int ifaceIndex = mDependencies.getInterfaceIndex(socketInfo.mSocket);
+            final SocketKey socketKey = new SocketKey(network, ifaceIndex);
+            cb.onSocketCreated(socketKey, socketInfo.mSocket, socketInfo.mAddresses);
             mSocketRequestMonitor.onSocketRequestFulfilled(network, socketInfo.mSocket,
                     socketInfo.mTransports);
         }
@@ -652,8 +664,9 @@
                     createLPForTetheredInterface(interfaceName, ifaceIndex));
         } else {
             // Notify the socket for requested network.
-            cb.onSocketCreated(
-                    null /* network */, socketInfo.mSocket, socketInfo.mAddresses);
+            final int ifaceIndex = mDependencies.getInterfaceIndex(socketInfo.mSocket);
+            final SocketKey socketKey = new SocketKey(ifaceIndex);
+            cb.onSocketCreated(socketKey, socketInfo.mSocket, socketInfo.mAddresses);
             mSocketRequestMonitor.onSocketRequestFulfilled(null /* socketNetwork */,
                     socketInfo.mSocket, socketInfo.mTransports);
         }
@@ -741,21 +754,21 @@
          * This may be called immediately when the request is registered with an existing socket,
          * if it had been created previously for other requests.
          */
-        default void onSocketCreated(@Nullable Network network, @NonNull MdnsInterfaceSocket socket,
-                @NonNull List<LinkAddress> addresses) {}
+        default void onSocketCreated(@NonNull SocketKey socketKey,
+                @NonNull MdnsInterfaceSocket socket, @NonNull List<LinkAddress> addresses) {}
 
         /**
          * Notify that the interface was destroyed, so the provided socket cannot be used anymore.
          *
          * This indicates that although the socket was still requested, it had to be destroyed.
          */
-        default void onInterfaceDestroyed(@Nullable Network network,
+        default void onInterfaceDestroyed(@NonNull SocketKey socketKey,
                 @NonNull MdnsInterfaceSocket socket) {}
 
         /**
          * Notify the interface addresses have changed for the network.
          */
-        default void onAddressesChanged(@Nullable Network network,
+        default void onAddressesChanged(@NonNull SocketKey socketKey,
                 @NonNull MdnsInterfaceSocket socket, @NonNull List<LinkAddress> addresses) {}
     }
 
diff --git a/service-t/src/com/android/server/connectivity/mdns/SocketKey.java b/service-t/src/com/android/server/connectivity/mdns/SocketKey.java
new file mode 100644
index 0000000..a893acb
--- /dev/null
+++ b/service-t/src/com/android/server/connectivity/mdns/SocketKey.java
@@ -0,0 +1,72 @@
+/*
+ * Copyright (C) 2023 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.server.connectivity.mdns;
+
+import android.annotation.Nullable;
+import android.net.Network;
+
+import java.util.Objects;
+
+/**
+ * A class that identifies a socket.
+ *
+ * <p> A socket is typically created with an associated network. However, tethering interfaces do
+ * not have an associated network, only an interface index. This means that the socket cannot be
+ * identified in some places. Therefore, this class is necessary for identifying a socket. It
+ * includes both the network and interface index.
+ */
+public class SocketKey {
+    @Nullable
+    private final Network mNetwork;
+    private final int mInterfaceIndex;
+
+    SocketKey(int interfaceIndex) {
+        this(null /* network */, interfaceIndex);
+    }
+
+    SocketKey(@Nullable Network network, int interfaceIndex) {
+        mNetwork = network;
+        mInterfaceIndex = interfaceIndex;
+    }
+
+    public Network getNetwork() {
+        return mNetwork;
+    }
+
+    public int getInterfaceIndex() {
+        return mInterfaceIndex;
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(mNetwork, mInterfaceIndex);
+    }
+
+    @Override
+    public boolean equals(@Nullable Object other) {
+        if (!(other instanceof SocketKey)) {
+            return false;
+        }
+        return Objects.equals(mNetwork, ((SocketKey) other).mNetwork)
+                && mInterfaceIndex == ((SocketKey) other).mInterfaceIndex;
+    }
+
+    @Override
+    public String toString() {
+        return "SocketKey{ network=" + mNetwork + " interfaceIndex=" + mInterfaceIndex + " }";
+    }
+}
diff --git a/service-t/src/com/android/server/ethernet/EthernetNetworkFactory.java b/service-t/src/com/android/server/ethernet/EthernetNetworkFactory.java
index 6776920..ece10f3 100644
--- a/service-t/src/com/android/server/ethernet/EthernetNetworkFactory.java
+++ b/service-t/src/com/android/server/ethernet/EthernetNetworkFactory.java
@@ -313,17 +313,12 @@
                 mIpClientShutdownCv.block();
             }
 
-            // At the time IpClient is stopped, an IpClient event may have already been posted on
-            // the back of the handler and is awaiting execution. Once that event is executed, the
-            // associated callback object may not be valid anymore
-            // (NetworkInterfaceState#mIpClientCallback points to a different object / null).
-            private boolean isCurrentCallback() {
-                return this == mIpClientCallback;
-            }
-
-            private void handleIpEvent(final @NonNull Runnable r) {
+            private void safelyPostOnHandler(Runnable r) {
                 mHandler.post(() -> {
-                    if (!isCurrentCallback()) {
+                    if (this != mIpClientCallback) {
+                        // At the time IpClient is stopped, an IpClient event may have already been
+                        // posted on the handler and is awaiting execution. Once that event is
+                        // executed, the associated callback object may not be valid anymore.
                         Log.i(TAG, "Ignoring stale IpClientCallbacks " + this);
                         return;
                     }
@@ -333,24 +328,24 @@
 
             @Override
             public void onProvisioningSuccess(LinkProperties newLp) {
-                handleIpEvent(() -> onIpLayerStarted(newLp));
+                safelyPostOnHandler(() -> onIpLayerStarted(newLp));
             }
 
             @Override
             public void onProvisioningFailure(LinkProperties newLp) {
                 // This cannot happen due to provisioning timeout, because our timeout is 0. It can
                 // happen due to errors while provisioning or on provisioning loss.
-                handleIpEvent(() -> onIpLayerStopped());
+                safelyPostOnHandler(() -> onIpLayerStopped());
             }
 
             @Override
             public void onLinkPropertiesChange(LinkProperties newLp) {
-                handleIpEvent(() -> updateLinkProperties(newLp));
+                safelyPostOnHandler(() -> updateLinkProperties(newLp));
             }
 
             @Override
             public void onReachabilityLost(String logMsg) {
-                handleIpEvent(() -> updateNeighborLostEvent(logMsg));
+                safelyPostOnHandler(() -> updateNeighborLostEvent(logMsg));
             }
 
             @Override
diff --git a/service/src/com/android/server/ConnectivityService.java b/service/src/com/android/server/ConnectivityService.java
index d4d9233..39a500c 100755
--- a/service/src/com/android/server/ConnectivityService.java
+++ b/service/src/com/android/server/ConnectivityService.java
@@ -318,7 +318,6 @@
 import java.io.InterruptedIOException;
 import java.io.PrintWriter;
 import java.io.Writer;
-import java.lang.IllegalArgumentException;
 import java.net.Inet4Address;
 import java.net.InetAddress;
 import java.net.InetSocketAddress;
@@ -1432,6 +1431,7 @@
         /**
          * @see ClatCoordinator
          */
+        @RequiresApi(Build.VERSION_CODES.TIRAMISU)
         public ClatCoordinator getClatCoordinator(INetd netd) {
             return new ClatCoordinator(
                 new ClatCoordinator.Dependencies() {
@@ -4009,7 +4009,7 @@
                     // the destroyed flag is only just above the "current satisfier wins"
                     // tie-breaker. But technically anything that affects scoring should rematch.
                     rematchAllNetworksAndRequests();
-                    mHandler.postDelayed(() -> nai.disconnect(), timeoutMs);
+                    mHandler.postDelayed(() -> disconnectAndDestroyNetwork(nai), timeoutMs);
                     break;
                 }
             }
@@ -4608,6 +4608,9 @@
         if (DBG) {
             log(nai.toShortString() + " disconnected, was satisfying " + nai.numNetworkRequests());
         }
+
+        nai.disconnect();
+
         // Clear all notifications of this network.
         mNotifier.clearNotification(nai.network.getNetId());
         // A network agent has disconnected.
@@ -5892,7 +5895,7 @@
                     final NetworkAgentInfo nai = getNetworkAgentInfoForNetwork((Network) msg.obj);
                     if (nai == null) break;
                     nai.onPreventAutomaticReconnect();
-                    nai.disconnect();
+                    disconnectAndDestroyNetwork(nai);
                     break;
                 case EVENT_SET_VPN_NETWORK_PREFERENCE:
                     handleSetVpnNetworkPreference((VpnNetworkPreferenceInfo) msg.obj);
@@ -9037,7 +9040,7 @@
                 break;
             }
         }
-        nai.disconnect();
+        disconnectAndDestroyNetwork(nai);
     }
 
     private void handleLingerComplete(NetworkAgentInfo oldNetwork) {
@@ -9579,7 +9582,10 @@
         updateLegacyTypeTrackerAndVpnLockdownForRematch(changes, nais);
 
         // Tear down all unneeded networks.
-        for (NetworkAgentInfo nai : mNetworkAgentInfos) {
+        // Iterate in reverse order because teardownUnneededNetwork removes the nai from
+        // mNetworkAgentInfos.
+        for (int i = mNetworkAgentInfos.size() - 1; i >= 0; i--) {
+            final NetworkAgentInfo nai = mNetworkAgentInfos.valueAt(i);
             if (unneeded(nai, UnneededFor.TEARDOWN)) {
                 if (nai.getInactivityExpiry() > 0) {
                     // This network has active linger timers and no requests, but is not
@@ -9962,7 +9968,6 @@
             // This has to happen after matching the requests, because callbacks are just requests.
             notifyNetworkCallbacks(networkAgent, ConnectivityManager.CALLBACK_PRECHECK);
         } else if (state == NetworkInfo.State.DISCONNECTED) {
-            networkAgent.disconnect();
             if (networkAgent.isVPN()) {
                 updateVpnUids(networkAgent, networkAgent.networkCapabilities, null);
             }
diff --git a/service/src/com/android/server/connectivity/AutomaticOnOffKeepaliveTracker.java b/service/src/com/android/server/connectivity/AutomaticOnOffKeepaliveTracker.java
index 2adc028..6ba2033 100644
--- a/service/src/com/android/server/connectivity/AutomaticOnOffKeepaliveTracker.java
+++ b/service/src/com/android/server/connectivity/AutomaticOnOffKeepaliveTracker.java
@@ -94,6 +94,7 @@
     private static final int ADJUST_TCP_POLLING_DELAY_MS = 2000;
     private static final String AUTOMATIC_ON_OFF_KEEPALIVE_VERSION =
             "automatic_on_off_keepalive_version";
+    public static final long METRICS_COLLECTION_DURATION_MS = 24 * 60 * 60 * 1_000L;
 
     // ConnectivityService parses message constants from itself and AutomaticOnOffKeepaliveTracker
     // with MessageUtils for debugging purposes, and crashes if some messages have the same values.
@@ -180,6 +181,9 @@
     private final LocalLog mEventLog = new LocalLog(MAX_EVENTS_LOGS);
 
     private final KeepaliveStatsTracker mKeepaliveStatsTracker;
+
+    private final long mMetricsWriteTimeBase;
+
     /**
      * Information about a managed keepalive.
      *
@@ -248,7 +252,7 @@
         }
 
         public Network getNetwork() {
-            return mKi.getNai().network;
+            return mKi.getNai().network();
         }
 
         @Nullable
@@ -311,7 +315,26 @@
                 mContext, mConnectivityServiceHandler);
 
         mAlarmManager = mDependencies.getAlarmManager(context);
-        mKeepaliveStatsTracker = new KeepaliveStatsTracker(handler);
+        mKeepaliveStatsTracker =
+                mDependencies.newKeepaliveStatsTracker(context, handler);
+
+        final long time = mDependencies.getElapsedRealtime();
+        mMetricsWriteTimeBase = time % METRICS_COLLECTION_DURATION_MS;
+        final long triggerAtMillis = mMetricsWriteTimeBase + METRICS_COLLECTION_DURATION_MS;
+        mAlarmManager.set(AlarmManager.ELAPSED_REALTIME_WAKEUP, triggerAtMillis, TAG,
+                this::writeMetricsAndRescheduleAlarm, handler);
+    }
+
+    private void writeMetricsAndRescheduleAlarm() {
+        mKeepaliveStatsTracker.writeAndResetMetrics();
+
+        final long time = mDependencies.getElapsedRealtime();
+        final long triggerAtMillis =
+                mMetricsWriteTimeBase
+                        + (time - time % METRICS_COLLECTION_DURATION_MS)
+                        + METRICS_COLLECTION_DURATION_MS;
+        mAlarmManager.set(AlarmManager.ELAPSED_REALTIME_WAKEUP, triggerAtMillis, TAG,
+                this::writeMetricsAndRescheduleAlarm, mConnectivityServiceHandler);
     }
 
     private void startTcpPollingAlarm(@NonNull AutomaticOnOffKeepalive ki) {
@@ -455,7 +478,13 @@
             return;
         }
         mEventLog.log("Start keepalive " + autoKi.mCallback + " on " + autoKi.getNetwork());
-        mKeepaliveStatsTracker.onStartKeepalive();
+        mKeepaliveStatsTracker.onStartKeepalive(
+                autoKi.getNetwork(),
+                autoKi.mKi.getSlot(),
+                autoKi.mKi.getNai().networkCapabilities,
+                autoKi.mKi.getKeepaliveIntervalSec(),
+                autoKi.mKi.getUid(),
+                STATE_ALWAYS_ON != autoKi.mAutomaticOnOffState);
 
         // Add automatic on/off request into list to track its life cycle.
         try {
@@ -483,7 +512,7 @@
                     + " with error " + error);
             return error;
         }
-        mKeepaliveStatsTracker.onResumeKeepalive();
+        mKeepaliveStatsTracker.onResumeKeepalive(ki.getNai().network(), ki.getSlot());
         mEventLog.log("Resumed successfully keepalive " + ki.mCallback + " on " + ki.mNai);
 
         return SUCCESS;
@@ -491,7 +520,7 @@
 
     private void handlePauseKeepalive(@NonNull final KeepaliveTracker.KeepaliveInfo ki) {
         mEventLog.log("Suspend keepalive " + ki.mCallback + " on " + ki.mNai);
-        mKeepaliveStatsTracker.onPauseKeepalive();
+        mKeepaliveStatsTracker.onPauseKeepalive(ki.getNai().network(), ki.getSlot());
         // TODO : mKT.handleStopKeepalive should take a KeepaliveInfo instead
         mKeepaliveTracker.handleStopKeepalive(ki.getNai(), ki.getSlot(), SUCCESS_PAUSED);
     }
@@ -515,7 +544,7 @@
 
     private void cleanupAutoOnOffKeepalive(@NonNull final AutomaticOnOffKeepalive autoKi) {
         ensureRunningOnHandlerThread();
-        mKeepaliveStatsTracker.onStopKeepalive(autoKi.mAutomaticOnOffState != STATE_SUSPENDED);
+        mKeepaliveStatsTracker.onStopKeepalive(autoKi.getNetwork(), autoKi.mKi.getSlot());
         autoKi.close();
         if (null != autoKi.mAlarmListener) mAlarmManager.cancel(autoKi.mAlarmListener);
 
@@ -892,6 +921,14 @@
         }
 
         /**
+         * Construct a new KeepaliveStatsTracker.
+         */
+        public KeepaliveStatsTracker newKeepaliveStatsTracker(@NonNull Context context,
+                @NonNull Handler connectivityserviceHander) {
+            return new KeepaliveStatsTracker(context, connectivityserviceHander);
+        }
+
+        /**
          * Find out if a feature is enabled from DeviceConfig.
          *
          * @param name The name of the property to look up.
diff --git a/service/src/com/android/server/connectivity/ClatCoordinator.java b/service/src/com/android/server/connectivity/ClatCoordinator.java
index fbe706c..d87f250 100644
--- a/service/src/com/android/server/connectivity/ClatCoordinator.java
+++ b/service/src/com/android/server/connectivity/ClatCoordinator.java
@@ -30,6 +30,7 @@
 import android.net.InetAddresses;
 import android.net.InterfaceConfigurationParcel;
 import android.net.IpPrefix;
+import android.os.Build;
 import android.os.ParcelFileDescriptor;
 import android.os.RemoteException;
 import android.os.ServiceSpecificException;
@@ -58,11 +59,14 @@
 import java.nio.ByteBuffer;
 import java.util.Objects;
 
+import androidx.annotation.RequiresApi;
+
 /**
  * This coordinator is responsible for providing clat relevant functionality.
  *
  * {@hide}
  */
+@RequiresApi(Build.VERSION_CODES.TIRAMISU)
 public class ClatCoordinator {
     private static final String TAG = ClatCoordinator.class.getSimpleName();
 
@@ -251,11 +255,6 @@
         /** Get ingress6 BPF map. */
         @Nullable
         public IBpfMap<ClatIngress6Key, ClatIngress6Value> getBpfIngress6Map() {
-            // Pre-T devices don't use ClatCoordinator to access clat map. Since Nat464Xlat
-            // initializes a ClatCoordinator object to avoid redundant null pointer check
-            // while using, ignore the BPF map initialization on pre-T devices.
-            // TODO: probably don't initialize ClatCoordinator object on pre-T devices.
-            if (!SdkLevel.isAtLeastT()) return null;
             try {
                 return new BpfMap<>(CLAT_INGRESS6_MAP_PATH,
                     BpfMap.BPF_F_RDWR, ClatIngress6Key.class, ClatIngress6Value.class);
@@ -268,11 +267,6 @@
         /** Get egress4 BPF map. */
         @Nullable
         public IBpfMap<ClatEgress4Key, ClatEgress4Value> getBpfEgress4Map() {
-            // Pre-T devices don't use ClatCoordinator to access clat map. Since Nat464Xlat
-            // initializes a ClatCoordinator object to avoid redundant null pointer check
-            // while using, ignore the BPF map initialization on pre-T devices.
-            // TODO: probably don't initialize ClatCoordinator object on pre-T devices.
-            if (!SdkLevel.isAtLeastT()) return null;
             try {
                 return new BpfMap<>(CLAT_EGRESS4_MAP_PATH,
                     BpfMap.BPF_F_RDWR, ClatEgress4Key.class, ClatEgress4Value.class);
@@ -285,11 +279,6 @@
         /** Get cookie tag map */
         @Nullable
         public IBpfMap<CookieTagMapKey, CookieTagMapValue> getBpfCookieTagMap() {
-            // Pre-T devices don't use ClatCoordinator to access clat map. Since Nat464Xlat
-            // initializes a ClatCoordinator object to avoid redundant null pointer check
-            // while using, ignore the BPF map initialization on pre-T devices.
-            // TODO: probably don't initialize ClatCoordinator object on pre-T devices.
-            if (!SdkLevel.isAtLeastT()) return null;
             try {
                 return new BpfMap<>(COOKIE_TAG_MAP_PATH,
                         BpfMap.BPF_F_RDWR, CookieTagMapKey.class, CookieTagMapValue.class);
diff --git a/service/src/com/android/server/connectivity/KeepaliveStatsTracker.java b/service/src/com/android/server/connectivity/KeepaliveStatsTracker.java
index 07140c4..d59d526 100644
--- a/service/src/com/android/server/connectivity/KeepaliveStatsTracker.java
+++ b/service/src/com/android/server/connectivity/KeepaliveStatsTracker.java
@@ -16,21 +16,46 @@
 
 package com.android.server.connectivity;
 
+import static android.telephony.SubscriptionManager.OnSubscriptionsChangedListener;
+
 import android.annotation.NonNull;
+import android.content.BroadcastReceiver;
+import android.content.Context;
+import android.content.Intent;
+import android.content.IntentFilter;
+import android.net.Network;
+import android.net.NetworkCapabilities;
+import android.net.NetworkSpecifier;
+import android.net.TelephonyNetworkSpecifier;
+import android.net.TransportInfo;
+import android.net.wifi.WifiInfo;
 import android.os.Handler;
 import android.os.SystemClock;
+import android.telephony.SubscriptionInfo;
+import android.telephony.SubscriptionManager;
+import android.telephony.TelephonyManager;
 import android.util.Log;
+import android.util.SparseArray;
+import android.util.SparseIntArray;
 
 import com.android.internal.annotations.VisibleForTesting;
 import com.android.metrics.DailykeepaliveInfoReported;
 import com.android.metrics.DurationForNumOfKeepalive;
 import com.android.metrics.DurationPerNumOfKeepalive;
+import com.android.metrics.KeepaliveLifetimeForCarrier;
+import com.android.metrics.KeepaliveLifetimePerCarrier;
+import com.android.modules.utils.BackgroundThread;
+import com.android.net.module.util.CollectionUtils;
+import com.android.server.ConnectivityStatsLog;
 
 import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.HashSet;
 import java.util.List;
+import java.util.Map;
 import java.util.Objects;
+import java.util.Set;
 
-// TODO(b/273451360): Also track KeepaliveLifetimeForCarrier and DailykeepaliveInfoReported
 /**
  * Tracks carrier and duration metrics of automatic on/off keepalives.
  *
@@ -44,6 +69,111 @@
     @NonNull private final Handler mConnectivityServiceHandler;
     @NonNull private final Dependencies mDependencies;
 
+    // Mapping of subId to carrierId. Updates are received from OnSubscriptionsChangedListener
+    private final SparseIntArray mCachedCarrierIdPerSubId = new SparseIntArray();
+    // The default subscription id obtained from SubscriptionManager.getDefaultSubscriptionId.
+    // Updates are received from the ACTION_DEFAULT_SUBSCRIPTION_CHANGED broadcast.
+    private int mCachedDefaultSubscriptionId = SubscriptionManager.INVALID_SUBSCRIPTION_ID;
+
+    // Class to store network information, lifetime durations and active state of a keepalive.
+    private static final class KeepaliveStats {
+        // The carrier ID for a keepalive, or TelephonyManager.UNKNOWN_CARRIER_ID(-1) if not set.
+        public final int carrierId;
+        // The transport types of the underlying network for each keepalive. A network may include
+        // multiple transport types. Each transport type is represented by a different bit, defined
+        // in NetworkCapabilities
+        public final int transportTypes;
+        // The keepalive interval in millis.
+        public final int intervalMs;
+        // The uid of the app that requested the keepalive.
+        public final int appUid;
+        // Indicates if the keepalive is an automatic keepalive.
+        public final boolean isAutoKeepalive;
+
+        // Snapshot of the lifetime stats
+        public static class LifetimeStats {
+            public final int lifetimeMs;
+            public final int activeLifetimeMs;
+
+            LifetimeStats(int lifetimeMs, int activeLifetimeMs) {
+                this.lifetimeMs = lifetimeMs;
+                this.activeLifetimeMs = activeLifetimeMs;
+            }
+        }
+
+        // The total time since the keepalive is started until it is stopped.
+        private int mLifetimeMs = 0;
+        // The total time the keepalive is active (not suspended).
+        private int mActiveLifetimeMs = 0;
+
+        // A timestamp of the most recent time the lifetime metrics was updated.
+        private long mLastUpdateLifetimeTimestamp;
+
+        // A flag to indicate if the keepalive is active.
+        private boolean mKeepaliveActive = true;
+
+        /**
+         * Gets the lifetime stats for the keepalive, updated to timeNow, and then resets it.
+         *
+         * @param timeNow a timestamp obtained using Dependencies.getElapsedRealtime
+         */
+        public LifetimeStats getAndResetLifetimeStats(long timeNow) {
+            updateLifetimeStatsAndSetActive(timeNow, mKeepaliveActive);
+            // Get a snapshot of the stats
+            final LifetimeStats lifetimeStats = new LifetimeStats(mLifetimeMs, mActiveLifetimeMs);
+            // Reset the stats
+            resetLifetimeStats(timeNow);
+
+            return lifetimeStats;
+        }
+
+        public boolean isKeepaliveActive() {
+            return mKeepaliveActive;
+        }
+
+        KeepaliveStats(
+                int carrierId,
+                int transportTypes,
+                int intervalSeconds,
+                int appUid,
+                boolean isAutoKeepalive,
+                long timeNow) {
+            this.carrierId = carrierId;
+            this.transportTypes = transportTypes;
+            this.intervalMs = intervalSeconds * 1000;
+            this.appUid = appUid;
+            this.isAutoKeepalive = isAutoKeepalive;
+            mLastUpdateLifetimeTimestamp = timeNow;
+        }
+
+        /**
+         * Updates the lifetime metrics to the given time and sets the active state. This should be
+         * called whenever the active state of the keepalive changes.
+         *
+         * @param timeNow a timestamp obtained using Dependencies.getElapsedRealtime
+         */
+        public void updateLifetimeStatsAndSetActive(long timeNow, boolean keepaliveActive) {
+            final int durationIncrease = (int) (timeNow - mLastUpdateLifetimeTimestamp);
+            mLifetimeMs += durationIncrease;
+            if (mKeepaliveActive) mActiveLifetimeMs += durationIncrease;
+
+            mLastUpdateLifetimeTimestamp = timeNow;
+            mKeepaliveActive = keepaliveActive;
+        }
+
+        /**
+         * Resets the lifetime metrics but does not reset the active/stopped state of the keepalive.
+         * This also updates the time to timeNow, ensuring stats will start from this time.
+         *
+         * @param timeNow a timestamp obtained using Dependencies.getElapsedRealtime
+         */
+        public void resetLifetimeStats(long timeNow) {
+            mLifetimeMs = 0;
+            mActiveLifetimeMs = 0;
+            mLastUpdateLifetimeTimestamp = timeNow;
+        }
+    }
+
     // List of duration stats metric where the index is the number of concurrent keepalives.
     // Each DurationForNumOfKeepalive message stores a registered duration and an active duration.
     // Registered duration is the total time spent with mNumRegisteredKeepalive == index.
@@ -51,6 +181,62 @@
     private final List<DurationForNumOfKeepalive.Builder> mDurationPerNumOfKeepalive =
             new ArrayList<>();
 
+    // Map of keepalives identified by the id from getKeepaliveId to their stats information.
+    private final SparseArray<KeepaliveStats> mKeepaliveStatsPerId = new SparseArray<>();
+
+    // Generate a unique integer using a given network's netId and the slot number.
+    // This is possible because netId is a 16 bit integer, so an integer with the first 16 bits as
+    // the netId and the last 16 bits as the slot number can be created. This allows slot numbers to
+    // be up to 2^16.
+    private int getKeepaliveId(@NonNull Network network, int slot) {
+        final int netId = network.getNetId();
+        if (netId < 0 || netId >= (1 << 16)) {
+            throw new IllegalArgumentException("Unexpected netId value: " + netId);
+        }
+        if (slot < 0 || slot >= (1 << 16)) {
+            throw new IllegalArgumentException("Unexpected slot value: " + slot);
+        }
+
+        return (netId << 16) + slot;
+    }
+
+    // Class to act as the key to aggregate the KeepaliveLifetimeForCarrier stats.
+    private static final class LifetimeKey {
+        public final int carrierId;
+        public final int transportTypes;
+        public final int intervalMs;
+
+        LifetimeKey(int carrierId, int transportTypes, int intervalMs) {
+            this.carrierId = carrierId;
+            this.transportTypes = transportTypes;
+            this.intervalMs = intervalMs;
+        }
+
+        @Override
+        public boolean equals(Object o) {
+            if (this == o) return true;
+            if (o == null || getClass() != o.getClass()) return false;
+
+            final LifetimeKey that = (LifetimeKey) o;
+
+            return carrierId == that.carrierId && transportTypes == that.transportTypes
+                    && intervalMs == that.intervalMs;
+        }
+
+        @Override
+        public int hashCode() {
+            return carrierId + 3 * transportTypes + 5 * intervalMs;
+        }
+    }
+
+    // Map to aggregate the KeepaliveLifetimeForCarrier stats using LifetimeKey as the key.
+    final Map<LifetimeKey, KeepaliveLifetimeForCarrier.Builder> mAggregateKeepaliveLifetime =
+            new HashMap<>();
+
+    private final Set<Integer> mAppUids = new HashSet<Integer>();
+    private int mNumKeepaliveRequests = 0;
+    private int mNumAutomaticKeepaliveRequests = 0;
+
     private int mNumRegisteredKeepalive = 0;
     private int mNumActiveKeepalive = 0;
 
@@ -60,23 +246,70 @@
     /** Dependency class */
     @VisibleForTesting
     public static class Dependencies {
-        // Returns a timestamp with the time base of SystemClock.uptimeMillis to keep durations
-        // relative to start time and avoid timezone change.
-        public long getUptimeMillis() {
-            return SystemClock.uptimeMillis();
+        // Returns a timestamp with the time base of SystemClock.elapsedRealtime to keep durations
+        // relative to start time and avoid timezone change, including time spent in deep sleep.
+        public long getElapsedRealtime() {
+            return SystemClock.elapsedRealtime();
         }
     }
 
-    public KeepaliveStatsTracker(@NonNull Handler handler) {
-        this(handler, new Dependencies());
+    public KeepaliveStatsTracker(@NonNull Context context, @NonNull Handler handler) {
+        this(context, handler, new Dependencies());
     }
 
     @VisibleForTesting
-    public KeepaliveStatsTracker(@NonNull Handler handler, @NonNull Dependencies dependencies) {
+    public KeepaliveStatsTracker(
+            @NonNull Context context,
+            @NonNull Handler handler,
+            @NonNull Dependencies dependencies) {
+        Objects.requireNonNull(context);
         mDependencies = Objects.requireNonNull(dependencies);
         mConnectivityServiceHandler = Objects.requireNonNull(handler);
 
-        mLastUpdateDurationsTimestamp = mDependencies.getUptimeMillis();
+        final SubscriptionManager subscriptionManager =
+                Objects.requireNonNull(context.getSystemService(SubscriptionManager.class));
+
+        mLastUpdateDurationsTimestamp = mDependencies.getElapsedRealtime();
+        context.registerReceiver(
+                new BroadcastReceiver() {
+                    @Override
+                    public void onReceive(Context context, Intent intent) {
+                        mCachedDefaultSubscriptionId =
+                                intent.getIntExtra(
+                                        SubscriptionManager.EXTRA_SUBSCRIPTION_INDEX,
+                                        SubscriptionManager.INVALID_SUBSCRIPTION_ID);
+                    }
+                },
+                new IntentFilter(SubscriptionManager.ACTION_DEFAULT_SUBSCRIPTION_CHANGED),
+                /* broadcastPermission= */ null,
+                mConnectivityServiceHandler);
+
+        // The default constructor for OnSubscriptionsChangedListener will always implicitly grab
+        // the looper of the current thread. In the case the current thread does not have a looper,
+        // this will throw. Therefore, post a runnable that creates it there.
+        // When the callback is called on the BackgroundThread, post a message on the CS handler
+        // thread to update the caches, which can only be touched there.
+        BackgroundThread.getHandler().post(() ->
+                subscriptionManager.addOnSubscriptionsChangedListener(
+                        r -> r.run(), new OnSubscriptionsChangedListener() {
+                            @Override
+                            public void onSubscriptionsChanged() {
+                                final List<SubscriptionInfo> activeSubInfoList =
+                                        subscriptionManager.getActiveSubscriptionInfoList();
+                                // A null subInfo list here indicates the current state is unknown
+                                // but not necessarily empty, simply ignore it. Another call to the
+                                // listener will be invoked in the future.
+                                if (activeSubInfoList == null) return;
+                                mConnectivityServiceHandler.post(() -> {
+                                    mCachedCarrierIdPerSubId.clear();
+
+                                    for (final SubscriptionInfo subInfo : activeSubInfoList) {
+                                        mCachedCarrierIdPerSubId.put(subInfo.getSubscriptionId(),
+                                                subInfo.getCarrierId());
+                                    }
+                                });
+                            }
+                        }));
     }
 
     /** Ensures the list of duration metrics is large enough for number of registered keepalives. */
@@ -106,7 +339,7 @@
      * change to mNumRegisteredKeepalive or mNumActiveKeepalive to keep the duration metrics
      * correct.
      *
-     * @param timeNow a timestamp obtained using Dependencies.getUptimeMillis
+     * @param timeNow a timestamp obtained using Dependencies.getElapsedRealtime
      */
     private void updateDurationsPerNumOfKeepalive(long timeNow) {
         if (mDurationPerNumOfKeepalive.size() < mNumRegisteredKeepalive) {
@@ -132,55 +365,201 @@
         mLastUpdateDurationsTimestamp = timeNow;
     }
 
-    /** Inform the KeepaliveStatsTracker a keepalive has just started and is active. */
-    public void onStartKeepalive() {
-        ensureRunningOnHandlerThread();
+    // TODO: Move this function to frameworks/libs/net/.../NetworkCapabilitiesUtils.java
+    private static int getSubId(@NonNull NetworkCapabilities nc, int defaultSubId) {
+        if (nc.hasTransport(NetworkCapabilities.TRANSPORT_CELLULAR)) {
+            final NetworkSpecifier networkSpecifier = nc.getNetworkSpecifier();
+            if (networkSpecifier instanceof TelephonyNetworkSpecifier) {
+                return ((TelephonyNetworkSpecifier) networkSpecifier).getSubscriptionId();
+            }
+            // Use the default subscriptionId.
+            return defaultSubId;
+        }
+        if (nc.hasTransport(NetworkCapabilities.TRANSPORT_WIFI)) {
+            final TransportInfo info = nc.getTransportInfo();
+            if (info instanceof WifiInfo) {
+                return ((WifiInfo) info).getSubscriptionId();
+            }
+        }
 
-        final long timeNow = mDependencies.getUptimeMillis();
+        return SubscriptionManager.INVALID_SUBSCRIPTION_ID;
+    }
+
+    private int getCarrierId(@NonNull NetworkCapabilities networkCapabilities) {
+        // Try to get the correct subscription id.
+        final int subId = getSubId(networkCapabilities, mCachedDefaultSubscriptionId);
+        if (subId == SubscriptionManager.INVALID_SUBSCRIPTION_ID) {
+            return TelephonyManager.UNKNOWN_CARRIER_ID;
+        }
+        return mCachedCarrierIdPerSubId.get(subId, TelephonyManager.UNKNOWN_CARRIER_ID);
+    }
+
+    private int getTransportTypes(@NonNull NetworkCapabilities networkCapabilities) {
+        // Transport types are internally packed as bits starting from bit 0. Casting to int works
+        // fine since for now and the foreseeable future, there will be less than 32 transports.
+        return (int) networkCapabilities.getTransportTypesInternal();
+    }
+
+    /** Inform the KeepaliveStatsTracker a keepalive has just started and is active. */
+    public void onStartKeepalive(
+            @NonNull Network network,
+            int slot,
+            @NonNull NetworkCapabilities nc,
+            int intervalSeconds,
+            int appUid,
+            boolean isAutoKeepalive) {
+        ensureRunningOnHandlerThread();
+        final int keepaliveId = getKeepaliveId(network, slot);
+        if (mKeepaliveStatsPerId.contains(keepaliveId)) {
+            throw new IllegalArgumentException(
+                    "Attempt to start keepalive stats on a known network, slot pair");
+        }
+
+        mNumKeepaliveRequests++;
+        if (isAutoKeepalive) mNumAutomaticKeepaliveRequests++;
+        mAppUids.add(appUid);
+
+        final long timeNow = mDependencies.getElapsedRealtime();
         updateDurationsPerNumOfKeepalive(timeNow);
 
         mNumRegisteredKeepalive++;
         mNumActiveKeepalive++;
+
+        final KeepaliveStats newKeepaliveStats =
+                new KeepaliveStats(
+                        getCarrierId(nc),
+                        getTransportTypes(nc),
+                        intervalSeconds,
+                        appUid,
+                        isAutoKeepalive,
+                        timeNow);
+
+        mKeepaliveStatsPerId.put(keepaliveId, newKeepaliveStats);
+    }
+
+    /**
+     * Inform the KeepaliveStatsTracker that the keepalive with the given network, slot pair has
+     * updated its active state to keepaliveActive.
+     *
+     * @return the KeepaliveStats associated with the network, slot pair or null if it is unknown.
+     */
+    private @NonNull KeepaliveStats onKeepaliveActive(
+            @NonNull Network network, int slot, boolean keepaliveActive) {
+        final long timeNow = mDependencies.getElapsedRealtime();
+        return onKeepaliveActive(network, slot, keepaliveActive, timeNow);
+    }
+
+    /**
+     * Inform the KeepaliveStatsTracker that the keepalive with the given network, slot pair has
+     * updated its active state to keepaliveActive.
+     *
+     * @param network the network of the keepalive
+     * @param slot the slot number of the keepalive
+     * @param keepaliveActive the new active state of the keepalive
+     * @param timeNow a timestamp obtained using Dependencies.getElapsedRealtime
+     * @return the KeepaliveStats associated with the network, slot pair or null if it is unknown.
+     */
+    private @NonNull KeepaliveStats onKeepaliveActive(
+            @NonNull Network network, int slot, boolean keepaliveActive, long timeNow) {
+        ensureRunningOnHandlerThread();
+
+        final int keepaliveId = getKeepaliveId(network, slot);
+        if (!mKeepaliveStatsPerId.contains(keepaliveId)) {
+            throw new IllegalArgumentException(
+                    "Attempt to set active keepalive on an unknown network, slot pair");
+        }
+        updateDurationsPerNumOfKeepalive(timeNow);
+
+        final KeepaliveStats keepaliveStats = mKeepaliveStatsPerId.get(keepaliveId);
+        if (keepaliveActive != keepaliveStats.isKeepaliveActive()) {
+            mNumActiveKeepalive += keepaliveActive ? 1 : -1;
+        }
+
+        keepaliveStats.updateLifetimeStatsAndSetActive(timeNow, keepaliveActive);
+        return keepaliveStats;
     }
 
     /** Inform the KeepaliveStatsTracker a keepalive has just been paused. */
-    public void onPauseKeepalive() {
-        ensureRunningOnHandlerThread();
-
-        final long timeNow = mDependencies.getUptimeMillis();
-        updateDurationsPerNumOfKeepalive(timeNow);
-
-        mNumActiveKeepalive--;
+    public void onPauseKeepalive(@NonNull Network network, int slot) {
+        onKeepaliveActive(network, slot, /* keepaliveActive= */ false);
     }
 
     /** Inform the KeepaliveStatsTracker a keepalive has just been resumed. */
-    public void onResumeKeepalive() {
-        ensureRunningOnHandlerThread();
-
-        final long timeNow = mDependencies.getUptimeMillis();
-        updateDurationsPerNumOfKeepalive(timeNow);
-
-        mNumActiveKeepalive++;
+    public void onResumeKeepalive(@NonNull Network network, int slot) {
+        onKeepaliveActive(network, slot, /* keepaliveActive= */ true);
     }
 
     /** Inform the KeepaliveStatsTracker a keepalive has just been stopped. */
-    public void onStopKeepalive(boolean wasActive) {
-        ensureRunningOnHandlerThread();
+    public void onStopKeepalive(@NonNull Network network, int slot) {
+        final int keepaliveId = getKeepaliveId(network, slot);
+        final long timeNow = mDependencies.getElapsedRealtime();
 
-        final long timeNow = mDependencies.getUptimeMillis();
-        updateDurationsPerNumOfKeepalive(timeNow);
+        final KeepaliveStats keepaliveStats =
+                onKeepaliveActive(network, slot, /* keepaliveActive= */ false, timeNow);
 
         mNumRegisteredKeepalive--;
-        if (wasActive) mNumActiveKeepalive--;
+
+        // add to the aggregate since it will be removed.
+        addToAggregateKeepaliveLifetime(keepaliveStats, timeNow);
+        // free up the slot.
+        mKeepaliveStatsPerId.remove(keepaliveId);
+    }
+
+    /**
+     * Updates and adds the lifetime metric of keepaliveStats to the aggregate.
+     *
+     * @param keepaliveStats the stats to add to the aggregate
+     * @param timeNow a timestamp obtained using Dependencies.getElapsedRealtime
+     */
+    private void addToAggregateKeepaliveLifetime(
+            @NonNull KeepaliveStats keepaliveStats, long timeNow) {
+
+        final KeepaliveStats.LifetimeStats lifetimeStats =
+                keepaliveStats.getAndResetLifetimeStats(timeNow);
+
+        final LifetimeKey key =
+                new LifetimeKey(
+                        keepaliveStats.carrierId,
+                        keepaliveStats.transportTypes,
+                        keepaliveStats.intervalMs);
+
+        KeepaliveLifetimeForCarrier.Builder keepaliveLifetimeForCarrier =
+                mAggregateKeepaliveLifetime.get(key);
+
+        if (keepaliveLifetimeForCarrier == null) {
+            keepaliveLifetimeForCarrier =
+                    KeepaliveLifetimeForCarrier.newBuilder()
+                            .setCarrierId(keepaliveStats.carrierId)
+                            .setTransportTypes(keepaliveStats.transportTypes)
+                            .setIntervalsMsec(keepaliveStats.intervalMs);
+            mAggregateKeepaliveLifetime.put(key, keepaliveLifetimeForCarrier);
+        }
+
+        keepaliveLifetimeForCarrier.setLifetimeMsec(
+                keepaliveLifetimeForCarrier.getLifetimeMsec() + lifetimeStats.lifetimeMs);
+        keepaliveLifetimeForCarrier.setActiveLifetimeMsec(
+                keepaliveLifetimeForCarrier.getActiveLifetimeMsec()
+                        + lifetimeStats.activeLifetimeMs);
     }
 
     /**
      * Builds and returns DailykeepaliveInfoReported proto.
+     *
+     * @return the DailykeepaliveInfoReported proto that was built.
      */
-    public DailykeepaliveInfoReported buildKeepaliveMetrics() {
+    @VisibleForTesting
+    public @NonNull DailykeepaliveInfoReported buildKeepaliveMetrics() {
         ensureRunningOnHandlerThread();
+        final long timeNow = mDependencies.getElapsedRealtime();
+        return buildKeepaliveMetrics(timeNow);
+    }
 
-        final long timeNow = mDependencies.getUptimeMillis();
+    /**
+     * Updates the metrics to timeNow and builds and returns DailykeepaliveInfoReported proto.
+     *
+     * @param timeNow a timestamp obtained using Dependencies.getElapsedRealtime
+     */
+    private @NonNull DailykeepaliveInfoReported buildKeepaliveMetrics(long timeNow) {
         updateDurationsPerNumOfKeepalive(timeNow);
 
         final DurationPerNumOfKeepalive.Builder durationPerNumOfKeepalive =
@@ -191,21 +570,82 @@
                         durationPerNumOfKeepalive.addDurationForNumOfKeepalive(
                                 durationForNumOfKeepalive));
 
+        final KeepaliveLifetimePerCarrier.Builder keepaliveLifetimePerCarrier =
+                KeepaliveLifetimePerCarrier.newBuilder();
+
+        for (int i = 0; i < mKeepaliveStatsPerId.size(); i++) {
+            final KeepaliveStats keepaliveStats = mKeepaliveStatsPerId.valueAt(i);
+            addToAggregateKeepaliveLifetime(keepaliveStats, timeNow);
+        }
+
+        // Fill keepalive carrier stats to the proto
+        mAggregateKeepaliveLifetime
+                .values()
+                .forEach(
+                        keepaliveLifetimeForCarrier ->
+                                keepaliveLifetimePerCarrier.addKeepaliveLifetimeForCarrier(
+                                        keepaliveLifetimeForCarrier));
+
         final DailykeepaliveInfoReported.Builder dailyKeepaliveInfoReported =
                 DailykeepaliveInfoReported.newBuilder();
 
-        // TODO(b/273451360): fill all the other values and write to ConnectivityStatsLog.
         dailyKeepaliveInfoReported.setDurationPerNumOfKeepalive(durationPerNumOfKeepalive);
+        dailyKeepaliveInfoReported.setKeepaliveLifetimePerCarrier(keepaliveLifetimePerCarrier);
+        dailyKeepaliveInfoReported.setKeepaliveRequests(mNumKeepaliveRequests);
+        dailyKeepaliveInfoReported.setAutomaticKeepaliveRequests(mNumAutomaticKeepaliveRequests);
+        dailyKeepaliveInfoReported.setDistinctUserCount(mAppUids.size());
+        dailyKeepaliveInfoReported.addAllUid(mAppUids);
 
         return dailyKeepaliveInfoReported.build();
     }
 
-    /** Resets the stored metrics but maintains the state of keepalives */
-    public void resetMetrics() {
+    /**
+     * Builds and resets the stored metrics. Similar to buildKeepaliveMetrics but also resets the
+     * metrics while maintaining the state of the keepalives.
+     *
+     * @return the DailykeepaliveInfoReported proto that was built.
+     */
+    @VisibleForTesting
+    public @NonNull DailykeepaliveInfoReported buildAndResetMetrics() {
         ensureRunningOnHandlerThread();
+        final long timeNow = mDependencies.getElapsedRealtime();
+
+        final DailykeepaliveInfoReported metrics = buildKeepaliveMetrics(timeNow);
 
         mDurationPerNumOfKeepalive.clear();
+        mAggregateKeepaliveLifetime.clear();
+        mAppUids.clear();
+        mNumKeepaliveRequests = 0;
+        mNumAutomaticKeepaliveRequests = 0;
+
+        // Update the metrics with the existing keepalives.
         ensureDurationPerNumOfKeepaliveSize();
+
+        mAggregateKeepaliveLifetime.clear();
+        // Reset the stats for existing keepalives
+        for (int i = 0; i < mKeepaliveStatsPerId.size(); i++) {
+            final KeepaliveStats keepaliveStats = mKeepaliveStatsPerId.valueAt(i);
+            keepaliveStats.resetLifetimeStats(timeNow);
+            mAppUids.add(keepaliveStats.appUid);
+            mNumKeepaliveRequests++;
+            if (keepaliveStats.isAutoKeepalive) mNumAutomaticKeepaliveRequests++;
+        }
+
+        return metrics;
+    }
+
+    /** Writes the stored metrics to ConnectivityStatsLog and resets.  */
+    public void writeAndResetMetrics() {
+        ensureRunningOnHandlerThread();
+        final DailykeepaliveInfoReported dailyKeepaliveInfoReported = buildAndResetMetrics();
+        ConnectivityStatsLog.write(
+                ConnectivityStatsLog.DAILY_KEEPALIVE_INFO_REPORTED,
+                dailyKeepaliveInfoReported.getDurationPerNumOfKeepalive().toByteArray(),
+                dailyKeepaliveInfoReported.getKeepaliveLifetimePerCarrier().toByteArray(),
+                dailyKeepaliveInfoReported.getKeepaliveRequests(),
+                dailyKeepaliveInfoReported.getAutomaticKeepaliveRequests(),
+                dailyKeepaliveInfoReported.getDistinctUserCount(),
+                CollectionUtils.toIntArray(dailyKeepaliveInfoReported.getUidList()));
     }
 
     private void ensureRunningOnHandlerThread() {
diff --git a/service/src/com/android/server/connectivity/KeepaliveTracker.java b/service/src/com/android/server/connectivity/KeepaliveTracker.java
index 1fd8a62..b4f74d5 100644
--- a/service/src/com/android/server/connectivity/KeepaliveTracker.java
+++ b/service/src/com/android/server/connectivity/KeepaliveTracker.java
@@ -271,6 +271,10 @@
             return mInterval;
         }
 
+        public int getUid() {
+            return mUid;
+        }
+
         private int checkNetworkConnected() {
             if (!mNai.networkInfo.isConnectedOrConnecting()) {
                 return ERROR_INVALID_NETWORK;
diff --git a/service/src/com/android/server/connectivity/Nat464Xlat.java b/service/src/com/android/server/connectivity/Nat464Xlat.java
index 2ac2ad3..90cddda 100644
--- a/service/src/com/android/server/connectivity/Nat464Xlat.java
+++ b/service/src/com/android/server/connectivity/Nat464Xlat.java
@@ -101,9 +101,9 @@
     private String mIface;
     private Inet6Address mIPv6Address;
     private State mState = State.IDLE;
-    private ClatCoordinator mClatCoordinator;
+    private final ClatCoordinator mClatCoordinator;  // non-null iff T+
 
-    private boolean mEnableClatOnCellular;
+    private final boolean mEnableClatOnCellular;
     private boolean mPrefixDiscoveryRunning;
 
     public Nat464Xlat(NetworkAgentInfo nai, INetd netd, IDnsResolver dnsResolver,
@@ -112,7 +112,11 @@
         mNetd = netd;
         mNetwork = nai;
         mEnableClatOnCellular = deps.getCellular464XlatEnabled();
-        mClatCoordinator = deps.getClatCoordinator(mNetd);
+        if (SdkLevel.isAtLeastT()) {
+            mClatCoordinator = deps.getClatCoordinator(mNetd);
+        } else {
+            mClatCoordinator = null;
+        }
     }
 
     /**
diff --git a/service/src/com/android/server/connectivity/NetworkDiagnostics.java b/service/src/com/android/server/connectivity/NetworkDiagnostics.java
index 4f80d47..a367d9d 100644
--- a/service/src/com/android/server/connectivity/NetworkDiagnostics.java
+++ b/service/src/com/android/server/connectivity/NetworkDiagnostics.java
@@ -19,6 +19,7 @@
 import static android.system.OsConstants.*;
 
 import static com.android.net.module.util.NetworkStackConstants.DNS_OVER_TLS_PORT;
+import static com.android.net.module.util.NetworkStackConstants.ETHER_MTU;
 import static com.android.net.module.util.NetworkStackConstants.ICMP_HEADER_LEN;
 import static com.android.net.module.util.NetworkStackConstants.IPV4_HEADER_MIN_LEN;
 import static com.android.net.module.util.NetworkStackConstants.IPV6_HEADER_LEN;
@@ -212,7 +213,8 @@
             mLinkProperties.addDnsServer(TEST_DNS6);
         }
 
-        final int mtu = mLinkProperties.getMtu();
+        final int lpMtu = mLinkProperties.getMtu();
+        final int mtu = lpMtu > 0 ? lpMtu : ETHER_MTU;
         for (RouteInfo route : mLinkProperties.getRoutes()) {
             if (route.getType() == RouteInfo.RTN_UNICAST && route.hasGateway()) {
                 InetAddress gateway = route.getGateway();
diff --git a/service/src/com/android/server/connectivity/TcpKeepaliveController.java b/service/src/com/android/server/connectivity/TcpKeepaliveController.java
index 0fd8604..4124e36 100644
--- a/service/src/com/android/server/connectivity/TcpKeepaliveController.java
+++ b/service/src/com/android/server/connectivity/TcpKeepaliveController.java
@@ -34,6 +34,7 @@
 import static com.android.net.module.util.NetworkStackConstants.IPV4_HEADER_MIN_LEN;
 
 import android.annotation.NonNull;
+import android.annotation.SuppressLint;
 import android.net.ISocketKeepaliveCallback;
 import android.net.InvalidPacketException;
 import android.net.NetworkUtils;
@@ -106,6 +107,8 @@
     private static final int TCP_REPAIR_ON = 1;
     // Reference include/uapi/linux/sockios.h
     private static final int SIOCINQ = FIONREAD;
+    // arch specific BSD socket API constant that predates Linux and Android
+    @SuppressLint("NewApi")
     private static final int SIOCOUTQ = TIOCOUTQ;
 
     /**
diff --git a/tests/common/java/android/net/NattKeepalivePacketDataTest.kt b/tests/common/java/android/net/NattKeepalivePacketDataTest.kt
index a7d1115..dde1d86 100644
--- a/tests/common/java/android/net/NattKeepalivePacketDataTest.kt
+++ b/tests/common/java/android/net/NattKeepalivePacketDataTest.kt
@@ -22,15 +22,16 @@
 import android.os.Build
 import androidx.test.filters.SmallTest
 import androidx.test.runner.AndroidJUnit4
+import com.android.testutils.ConnectivityModuleTest
 import com.android.testutils.DevSdkIgnoreRule
 import com.android.testutils.DevSdkIgnoreRule.IgnoreUpTo
 import com.android.testutils.assertEqualBothWays
 import com.android.testutils.assertParcelingIsLossless
 import com.android.testutils.parcelingRoundTrip
 import java.net.InetAddress
+import kotlin.test.assertFailsWith
 import org.junit.Assert.assertEquals
 import org.junit.Assert.assertNotEquals
-import org.junit.Assert.fail
 import org.junit.Rule
 import org.junit.Test
 import org.junit.runner.RunWith
@@ -46,6 +47,8 @@
     private val TEST_SRC_ADDRV4 = "198.168.0.2".address()
     private val TEST_DST_ADDRV4 = "198.168.0.1".address()
     private val TEST_ADDRV6 = "2001:db8::1".address()
+    private val TEST_ADDRV4MAPPEDV6 = "::ffff:1.2.3.4".address()
+    private val TEST_ADDRV4 = "1.2.3.4".address()
 
     private fun String.address() = InetAddresses.parseNumericAddress(this)
     private fun nattKeepalivePacket(
@@ -57,33 +60,52 @@
 
     @Test @IgnoreUpTo(Build.VERSION_CODES.Q)
     fun testConstructor() {
-        try {
+        assertFailsWith<InvalidPacketException>(
+            "Dst port is not NATT port should cause exception") {
             nattKeepalivePacket(dstPort = TEST_PORT)
-            fail("Dst port is not NATT port should cause exception")
-        } catch (e: InvalidPacketException) {
-            assertEquals(e.error, ERROR_INVALID_PORT)
+        }.let {
+            assertEquals(it.error, ERROR_INVALID_PORT)
         }
 
-        try {
+        assertFailsWith<InvalidPacketException>("A v6 srcAddress should cause exception") {
             nattKeepalivePacket(srcAddress = TEST_ADDRV6)
-            fail("A v6 srcAddress should cause exception")
-        } catch (e: InvalidPacketException) {
-            assertEquals(e.error, ERROR_INVALID_IP_ADDRESS)
+        }.let {
+            assertEquals(it.error, ERROR_INVALID_IP_ADDRESS)
         }
 
-        try {
+        assertFailsWith<InvalidPacketException>("A v6 dstAddress should cause exception") {
             nattKeepalivePacket(dstAddress = TEST_ADDRV6)
-            fail("A v6 dstAddress should cause exception")
-        } catch (e: InvalidPacketException) {
-            assertEquals(e.error, ERROR_INVALID_IP_ADDRESS)
+        }.let {
+            assertEquals(it.error, ERROR_INVALID_IP_ADDRESS)
         }
 
-        try {
+        assertFailsWith<IllegalArgumentException>("Invalid data should cause exception") {
             parcelingRoundTrip(
-                    NattKeepalivePacketData(TEST_SRC_ADDRV4, TEST_PORT, TEST_DST_ADDRV4, TEST_PORT,
+                NattKeepalivePacketData(TEST_SRC_ADDRV4, TEST_PORT, TEST_DST_ADDRV4, TEST_PORT,
                     byteArrayOf(12, 31, 22, 44)))
-            fail("Invalid data should cause exception")
-        } catch (e: IllegalArgumentException) { }
+        }
+    }
+
+    @Test @IgnoreUpTo(Build.VERSION_CODES.R) @ConnectivityModuleTest
+    fun testConstructor_afterR() {
+        // v4 mapped v6 will be translated to a v4 address.
+        assertFailsWith<InvalidPacketException> {
+            nattKeepalivePacket(srcAddress = TEST_ADDRV6, dstAddress = TEST_ADDRV4MAPPEDV6)
+        }
+        assertFailsWith<InvalidPacketException> {
+            nattKeepalivePacket(srcAddress = TEST_ADDRV4MAPPEDV6, dstAddress = TEST_ADDRV6)
+        }
+
+        // Both src and dst address will be v4 after translation, so it won't cause exception.
+        val packet1 = nattKeepalivePacket(
+            dstAddress = TEST_ADDRV4MAPPEDV6, srcAddress = TEST_ADDRV4MAPPEDV6)
+        assertEquals(TEST_ADDRV4, packet1.srcAddress)
+        assertEquals(TEST_ADDRV4, packet1.dstAddress)
+
+        // Packet with v6 src and v6 dst address is valid.
+        val packet2 = nattKeepalivePacket(srcAddress = TEST_ADDRV6, dstAddress = TEST_ADDRV6)
+        assertEquals(TEST_ADDRV6, packet2.srcAddress)
+        assertEquals(TEST_ADDRV6, packet2.dstAddress)
     }
 
     @Test @IgnoreUpTo(Build.VERSION_CODES.Q)
diff --git a/tests/cts/net/jni/NativeMultinetworkJni.cpp b/tests/cts/net/jni/NativeMultinetworkJni.cpp
index 6610d10..f2214a3 100644
--- a/tests/cts/net/jni/NativeMultinetworkJni.cpp
+++ b/tests/cts/net/jni/NativeMultinetworkJni.cpp
@@ -42,11 +42,14 @@
 
 // Since the tests in this file commonly pass expression statements as parameters to these macros,
 // get the returned value of the statements to avoid statement double-called.
+// By checking ExceptionCheck(), these macros don't throw another exception if an exception has
+// been thrown, because ART's JNI disallows to throw another exception while an exception is
+// pending (See CheckThread in check_jni.cc).
 #define EXPECT_GE(env, actual_stmt, expected_stmt, msg)              \
     do {                                                             \
         const auto expected = (expected_stmt);                       \
         const auto actual = (actual_stmt);                           \
-        if (actual < expected) {                                     \
+        if (actual < expected && !env->ExceptionCheck()) {           \
             jniThrowExceptionFmt(env, "java/lang/AssertionError",    \
                     "%s:%d: %s EXPECT_GE: expected %d, got %d",      \
                     __FILE__, __LINE__, msg, expected, actual);      \
@@ -57,7 +60,7 @@
     do {                                                             \
         const auto expected = (expected_stmt);                       \
         const auto actual = (actual_stmt);                           \
-        if (actual <= expected) {                                    \
+        if (actual <= expected && !env->ExceptionCheck()) {          \
             jniThrowExceptionFmt(env, "java/lang/AssertionError",    \
                     "%s:%d: %s EXPECT_GT: expected %d, got %d",      \
                     __FILE__, __LINE__, msg, expected, actual);      \
@@ -68,7 +71,7 @@
     do {                                                             \
         const auto expected = (expected_stmt);                       \
         const auto actual = (actual_stmt);                           \
-        if (actual != expected) {                                    \
+        if (actual != expected && !env->ExceptionCheck()) {          \
             jniThrowExceptionFmt(env, "java/lang/AssertionError",    \
                     "%s:%d: %s EXPECT_EQ: expected %d, got %d",      \
                     __FILE__, __LINE__, msg, expected, actual);      \
diff --git a/tests/cts/net/src/android/net/cts/ConnectivityManagerTest.java b/tests/cts/net/src/android/net/cts/ConnectivityManagerTest.java
index ee2f6bb..1411a37 100644
--- a/tests/cts/net/src/android/net/cts/ConnectivityManagerTest.java
+++ b/tests/cts/net/src/android/net/cts/ConnectivityManagerTest.java
@@ -3410,6 +3410,7 @@
 
     private void checkFirewallBlocking(final DatagramSocket srcSock, final DatagramSocket dstSock,
             final boolean expectBlock, final int chain) throws Exception {
+        final int uid = Process.myUid();
         final Random random = new Random();
         final byte[] sendData = new byte[100];
         random.nextBytes(sendData);
@@ -3425,7 +3426,8 @@
             fail("Expect not to be blocked by firewall but sending packet was blocked:"
                     + " chain=" + chain
                     + " chainEnabled=" + mCm.getFirewallChainEnabled(chain)
-                    + " uidFirewallRule=" + mCm.getUidFirewallRule(chain, Process.myUid()));
+                    + " uid=" + uid
+                    + " uidFirewallRule=" + mCm.getUidFirewallRule(chain, uid));
         }
 
         dstSock.receive(pkt);
@@ -3435,7 +3437,8 @@
             fail("Expect to be blocked by firewall but sending packet was not blocked:"
                     + " chain=" + chain
                     + " chainEnabled=" + mCm.getFirewallChainEnabled(chain)
-                    + " uidFirewallRule=" + mCm.getUidFirewallRule(chain, Process.myUid()));
+                    + " uid=" + uid
+                    + " uidFirewallRule=" + mCm.getUidFirewallRule(chain, uid));
         }
     }
 
diff --git a/tests/cts/net/src/android/net/cts/NetworkAgentTest.kt b/tests/cts/net/src/android/net/cts/NetworkAgentTest.kt
index cf5fc50..9f8a05d 100644
--- a/tests/cts/net/src/android/net/cts/NetworkAgentTest.kt
+++ b/tests/cts/net/src/android/net/cts/NetworkAgentTest.kt
@@ -1247,15 +1247,15 @@
 
         // Connect a third network. Because network1 is awaiting replacement, network3 is preferred
         // as soon as it validates (until then, it is outscored by network1).
-        // The fact that the first events seen by matchAllCallback is the connection of network3
+        // The fact that the first event seen by matchAllCallback is the connection of network3
         // implicitly ensures that no callbacks are sent since network1 was lost.
         val (agent3, network3) = connectNetwork()
-        matchAllCallback.expectAvailableThenValidatedCallbacks(network3)
-        testCallback.expectAvailableDoubleValidatedCallbacks(network3)
-
         // As soon as the replacement arrives, network1 is disconnected.
         // Check that this happens before the replacement timeout (5 seconds) fires.
+        matchAllCallback.expectAvailableCallbacks(network3, validated = false)
         matchAllCallback.expect<Lost>(network1, 2_000 /* timeoutMs */)
+        matchAllCallback.expectCaps(network3) { it.hasCapability(NET_CAPABILITY_VALIDATED) }
+        testCallback.expectAvailableDoubleValidatedCallbacks(network3)
         agent1.expectCallback<OnNetworkUnwanted>()
 
         // Test lingering:
@@ -1301,7 +1301,7 @@
         val callback = TestableNetworkCallback()
         requestNetwork(makeTestNetworkRequest(specifier = specifier6), callback)
         val agent6 = createNetworkAgent(specifier = specifier6)
-        val network6 = agent6.register()
+        agent6.register()
         if (SdkLevel.isAtLeastU()) {
             agent6.expectCallback<OnNetworkCreated>()
         } else {
@@ -1368,8 +1368,9 @@
 
         val (newWifiAgent, newWifiNetwork) = connectNetwork(TRANSPORT_WIFI)
         testCallback.expectAvailableCallbacks(newWifiNetwork, validated = true)
-        matchAllCallback.expectAvailableThenValidatedCallbacks(newWifiNetwork)
+        matchAllCallback.expectAvailableCallbacks(newWifiNetwork, validated = false)
         matchAllCallback.expect<Lost>(wifiNetwork)
+        matchAllCallback.expectCaps(newWifiNetwork) { it.hasCapability(NET_CAPABILITY_VALIDATED) }
         wifiAgent.expectCallback<OnNetworkUnwanted>()
     }
 
diff --git a/tests/cts/net/util/java/android/net/cts/util/CtsNetUtils.java b/tests/cts/net/util/java/android/net/cts/util/CtsNetUtils.java
index ce789fc..21f1358 100644
--- a/tests/cts/net/util/java/android/net/cts/util/CtsNetUtils.java
+++ b/tests/cts/net/util/java/android/net/cts/util/CtsNetUtils.java
@@ -86,7 +86,7 @@
     private static final String FEATURE_IPSEC_TUNNEL_MIGRATION =
             "android.software.ipsec_tunnel_migration";
 
-    private static final int SOCKET_TIMEOUT_MS = 2000;
+    private static final int SOCKET_TIMEOUT_MS = 10_000;
     private static final int PRIVATE_DNS_PROBE_MS = 1_000;
 
     private static final int PRIVATE_DNS_SETTING_TIMEOUT_MS = 10_000;
diff --git a/tests/cts/netpermission/internetpermission/Android.bp b/tests/cts/netpermission/internetpermission/Android.bp
index 37ad7cb..5314396 100644
--- a/tests/cts/netpermission/internetpermission/Android.bp
+++ b/tests/cts/netpermission/internetpermission/Android.bp
@@ -29,5 +29,5 @@
         "cts",
         "general-tests",
     ],
-
+    host_required: ["net-tests-utils-host-common"],
 }
diff --git a/tests/cts/netpermission/internetpermission/AndroidTest.xml b/tests/cts/netpermission/internetpermission/AndroidTest.xml
index 3b23e72..e326844 100644
--- a/tests/cts/netpermission/internetpermission/AndroidTest.xml
+++ b/tests/cts/netpermission/internetpermission/AndroidTest.xml
@@ -24,6 +24,8 @@
         <option name="cleanup-apks" value="true" />
         <option name="test-file-name" value="CtsNetTestCasesInternetPermission.apk" />
     </target_preparer>
+    <target_preparer class="com.android.testutils.ConnectivityTestTargetPreparer">
+    </target_preparer>
     <test class="com.android.tradefed.testtype.AndroidJUnitTest" >
         <option name="package" value="android.networkpermission.internetpermission.cts" />
         <option name="runtime-hint" value="10s" />
diff --git a/tests/cts/netpermission/updatestatspermission/Android.bp b/tests/cts/netpermission/updatestatspermission/Android.bp
index 7a24886..40474db 100644
--- a/tests/cts/netpermission/updatestatspermission/Android.bp
+++ b/tests/cts/netpermission/updatestatspermission/Android.bp
@@ -29,5 +29,5 @@
         "cts",
         "general-tests",
     ],
-
+    host_required: ["net-tests-utils-host-common"],
 }
diff --git a/tests/cts/netpermission/updatestatspermission/AndroidTest.xml b/tests/cts/netpermission/updatestatspermission/AndroidTest.xml
index c47cad9..a1019fa 100644
--- a/tests/cts/netpermission/updatestatspermission/AndroidTest.xml
+++ b/tests/cts/netpermission/updatestatspermission/AndroidTest.xml
@@ -24,6 +24,8 @@
         <option name="cleanup-apks" value="true" />
         <option name="test-file-name" value="CtsNetTestCasesUpdateStatsPermission.apk" />
     </target_preparer>
+    <target_preparer class="com.android.testutils.ConnectivityTestTargetPreparer">
+    </target_preparer>
     <test class="com.android.tradefed.testtype.AndroidJUnitTest" >
         <option name="package" value="android.networkpermission.updatestatspermission.cts" />
         <option name="runtime-hint" value="10s" />
diff --git a/tests/unit/java/com/android/server/ConnectivityServiceTest.java b/tests/unit/java/com/android/server/ConnectivityServiceTest.java
index 43c6225..644910c 100755
--- a/tests/unit/java/com/android/server/ConnectivityServiceTest.java
+++ b/tests/unit/java/com/android/server/ConnectivityServiceTest.java
@@ -355,6 +355,7 @@
 import android.provider.Settings;
 import android.security.Credentials;
 import android.system.Os;
+import android.telephony.SubscriptionManager;
 import android.telephony.TelephonyManager;
 import android.telephony.data.EpsBearerQosSessionAttributes;
 import android.telephony.data.NrQosSessionAttributes;
@@ -616,6 +617,7 @@
     @Mock BroadcastOptionsShim mBroadcastOptionsShim;
     @Mock ActivityManager mActivityManager;
     @Mock DestroySocketsWrapper mDestroySocketsWrapper;
+    @Mock SubscriptionManager mSubscriptionManager;
 
     // BatteryStatsManager is final and cannot be mocked with regular mockito, so just mock the
     // underlying binder calls.
@@ -740,6 +742,7 @@
             if (Context.PAC_PROXY_SERVICE.equals(name)) return mPacProxyManager;
             if (Context.TETHERING_SERVICE.equals(name)) return mTetheringManager;
             if (Context.ACTIVITY_SERVICE.equals(name)) return mActivityManager;
+            if (Context.TELEPHONY_SUBSCRIPTION_SERVICE.equals(name)) return mSubscriptionManager;
             return super.getSystemService(name);
         }
 
@@ -2943,22 +2946,24 @@
         if (expectLingering) {
             generalCb.expectLosing(net1);
         }
-        generalCb.expectCaps(net2, c -> c.hasCapability(NET_CAPABILITY_VALIDATED));
-        defaultCb.expectAvailableDoubleValidatedCallbacks(net2);
 
         // Make sure cell 1 is unwanted immediately if the radio can't time share, but only
         // after some delay if it can.
         if (expectLingering) {
+            generalCb.expectCaps(net2, c -> c.hasCapability(NET_CAPABILITY_VALIDATED));
+            defaultCb.expectAvailableDoubleValidatedCallbacks(net2);
             net1.assertNotDisconnected(TEST_CALLBACK_TIMEOUT_MS); // always incurs the timeout
             generalCb.assertNoCallback();
             // assertNotDisconnected waited for TEST_CALLBACK_TIMEOUT_MS, so waiting for the
             // linger period gives TEST_CALLBACK_TIMEOUT_MS time for the event to process.
             net1.expectDisconnected(UNREASONABLY_LONG_ALARM_WAIT_MS);
+            generalCb.expect(LOST, net1);
         } else {
             net1.expectDisconnected(TEST_CALLBACK_TIMEOUT_MS);
+            generalCb.expect(LOST, net1);
+            generalCb.expectCaps(net2, c -> c.hasCapability(NET_CAPABILITY_VALIDATED));
+            defaultCb.expectAvailableDoubleValidatedCallbacks(net2);
         }
-        net1.disconnect();
-        generalCb.expect(LOST, net1);
 
         // Remove primary from net 2
         net2.setScore(new NetworkScore.Builder().build());
@@ -6853,10 +6858,6 @@
         ka = mCm.startNattKeepalive(myNet, validKaInterval, callback, myIPv6, 1234, dstIPv4);
         callback.expectError(PacketKeepalive.ERROR_INVALID_IP_ADDRESS);
 
-        // NAT-T is only supported for IPv4.
-        ka = mCm.startNattKeepalive(myNet, validKaInterval, callback, myIPv6, 1234, dstIPv6);
-        callback.expectError(PacketKeepalive.ERROR_INVALID_IP_ADDRESS);
-
         ka = mCm.startNattKeepalive(myNet, validKaInterval, callback, myIPv4, 123456, dstIPv4);
         callback.expectError(PacketKeepalive.ERROR_INVALID_PORT);
 
@@ -7007,13 +7008,6 @@
             callback.expectError(SocketKeepalive.ERROR_INVALID_IP_ADDRESS);
         }
 
-        // NAT-T is only supported for IPv4.
-        try (SocketKeepalive ka = mCm.createSocketKeepalive(
-                myNet, testSocket, myIPv6, dstIPv6, executor, callback)) {
-            ka.start(validKaInterval);
-            callback.expectError(SocketKeepalive.ERROR_INVALID_IP_ADDRESS);
-        }
-
         // Basic check before testing started keepalive.
         try (SocketKeepalive ka = mCm.createSocketKeepalive(
                 myNet, testSocket, myIPv4, dstIPv4, executor, callback)) {
diff --git a/tests/unit/java/com/android/server/VpnManagerServiceTest.java b/tests/unit/java/com/android/server/VpnManagerServiceTest.java
index deb56ef..bf23cd1 100644
--- a/tests/unit/java/com/android/server/VpnManagerServiceTest.java
+++ b/tests/unit/java/com/android/server/VpnManagerServiceTest.java
@@ -75,12 +75,15 @@
 @IgnoreUpTo(R) // VpnManagerService is not available before R
 @SmallTest
 public class VpnManagerServiceTest extends VpnTestBase {
+    private static final String CONTEXT_ATTRIBUTION_TAG = "VPN_MANAGER";
+
     @Rule
     public final DevSdkIgnoreRule mIgnoreRule = new DevSdkIgnoreRule();
 
     private static final int TIMEOUT_MS = 2_000;
 
     @Mock Context mContext;
+    @Mock Context mContextWithoutAttributionTag;
     @Mock Context mSystemContext;
     @Mock Context mUserAllContext;
     private HandlerThread mHandlerThread;
@@ -144,6 +147,13 @@
 
         mHandlerThread = new HandlerThread("TestVpnManagerService");
         mDeps = new VpnManagerServiceDependencies();
+
+        // The attribution tag is a dependency for IKE library to collect VPN metrics correctly
+        // and thus should not be changed without updating the IKE code.
+        doReturn(mContext)
+                .when(mContextWithoutAttributionTag)
+                .createAttributionContext(CONTEXT_ATTRIBUTION_TAG);
+
         doReturn(mUserAllContext).when(mContext).createContextAsUser(UserHandle.ALL, 0);
         doReturn(mSystemContext).when(mContext).createContextAsUser(UserHandle.SYSTEM, 0);
         doReturn(mPackageManager).when(mContext).getPackageManager();
@@ -153,7 +163,7 @@
         mockService(mContext, UserManager.class, Context.USER_SERVICE, mUserManager);
         doReturn(SYSTEM_USER).when(mUserManager).getUserInfo(eq(SYSTEM_USER_ID));
 
-        mService = new VpnManagerService(mContext, mDeps);
+        mService = new VpnManagerService(mContextWithoutAttributionTag, mDeps);
         mService.systemReady();
 
         final ArgumentCaptor<BroadcastReceiver> intentReceiverCaptor =
diff --git a/tests/unit/java/com/android/server/connectivity/AutomaticOnOffKeepaliveTrackerTest.java b/tests/unit/java/com/android/server/connectivity/AutomaticOnOffKeepaliveTrackerTest.java
index db65c2b..8232658 100644
--- a/tests/unit/java/com/android/server/connectivity/AutomaticOnOffKeepaliveTrackerTest.java
+++ b/tests/unit/java/com/android/server/connectivity/AutomaticOnOffKeepaliveTrackerTest.java
@@ -21,6 +21,7 @@
 import static android.net.NetworkAgent.CMD_STOP_SOCKET_KEEPALIVE;
 import static android.net.NetworkCapabilities.TRANSPORT_CELLULAR;
 
+import static com.android.server.connectivity.AutomaticOnOffKeepaliveTracker.METRICS_COLLECTION_DURATION_MS;
 import static com.android.testutils.HandlerUtils.visibleOnHandlerThread;
 
 import static org.junit.Assert.assertEquals;
@@ -36,11 +37,13 @@
 import static org.mockito.ArgumentMatchers.eq;
 import static org.mockito.ArgumentMatchers.longThat;
 import static org.mockito.Mockito.clearInvocations;
+import static org.mockito.Mockito.doCallRealMethod;
 import static org.mockito.Mockito.doNothing;
 import static org.mockito.Mockito.doReturn;
 import static org.mockito.Mockito.ignoreStubs;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.never;
+import static org.mockito.Mockito.spy;
 import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.verifyNoMoreInteractions;
 
@@ -67,6 +70,7 @@
 import android.os.Looper;
 import android.os.Message;
 import android.os.SystemClock;
+import android.telephony.SubscriptionManager;
 import android.test.suitebuilder.annotation.SmallTest;
 import android.util.Log;
 
@@ -121,7 +125,9 @@
     @Mock Context mCtx;
     @Mock AlarmManager mAlarmManager;
     @Mock NetworkAgentInfo mNai;
+    @Mock SubscriptionManager mSubscriptionManager;
 
+    KeepaliveStatsTracker mKeepaliveStatsTracker;
     TestKeepaliveTracker mKeepaliveTracker;
     AOOTestHandler mTestHandler;
     TestTcpKeepaliveController mTcpController;
@@ -298,10 +304,22 @@
         }
     }
 
+    private <T> void mockService(String serviceName, Class<T> serviceClass, T service) {
+        doReturn(serviceName).when(mCtx).getSystemServiceName(serviceClass);
+        doReturn(service).when(mCtx).getSystemService(serviceName);
+        if (mCtx.getSystemService(serviceClass) == null) {
+            // Test is using mockito-extended
+            doCallRealMethod().when(mCtx).getSystemService(serviceClass);
+        }
+    }
+
     @Before
     public void setup() throws Exception {
         MockitoAnnotations.initMocks(this);
 
+        mockService(Context.TELEPHONY_SUBSCRIPTION_SERVICE, SubscriptionManager.class,
+                mSubscriptionManager);
+
         mNai.networkCapabilities =
                 new NetworkCapabilities.Builder().addTransportType(TRANSPORT_CELLULAR).build();
         mNai.networkInfo = new NetworkInfo(TYPE_MOBILE, 0 /* subtype */, "LTE", "LTE");
@@ -329,8 +347,14 @@
         mTestHandler = new AOOTestHandler(mHandlerThread.getLooper());
         mTcpController = new TestTcpKeepaliveController(mTestHandler);
         mKeepaliveTracker = new TestKeepaliveTracker(mCtx, mTestHandler, mTcpController);
+        mKeepaliveStatsTracker = spy(new KeepaliveStatsTracker(mCtx, mTestHandler));
         doReturn(mKeepaliveTracker).when(mDependencies).newKeepaliveTracker(mCtx, mTestHandler);
+        doReturn(mKeepaliveStatsTracker)
+                .when(mDependencies)
+                .newKeepaliveStatsTracker(mCtx, mTestHandler);
+
         doReturn(true).when(mDependencies).isFeatureEnabled(any(), anyBoolean());
+        doReturn(0L).when(mDependencies).getElapsedRealtime();
         mAOOKeepaliveTracker =
                 new AutomaticOnOffKeepaliveTracker(mCtx, mTestHandler, mDependencies);
     }
@@ -484,6 +508,30 @@
         assertEquals(testInfo.underpinnedNetwork, mTestHandler.mLastAutoKi.getUnderpinnedNetwork());
     }
 
+    @Test
+    public void testAlarm_writeMetrics() throws Exception {
+        final ArgumentCaptor<AlarmManager.OnAlarmListener> listenerCaptor =
+                ArgumentCaptor.forClass(AlarmManager.OnAlarmListener.class);
+
+        // First AlarmManager.set call from the constructor.
+        verify(mAlarmManager).set(eq(AlarmManager.ELAPSED_REALTIME_WAKEUP),
+                eq(METRICS_COLLECTION_DURATION_MS), any() /* tag */, listenerCaptor.capture(),
+                eq(mTestHandler));
+
+        final AlarmManager.OnAlarmListener listener = listenerCaptor.getValue();
+
+        doReturn(METRICS_COLLECTION_DURATION_MS).when(mDependencies).getElapsedRealtime();
+        // For realism, the listener should be posted on the handler
+        mTestHandler.post(() -> listener.onAlarm());
+        HandlerUtils.waitForIdle(mTestHandler, TIMEOUT_MS);
+
+        verify(mKeepaliveStatsTracker).writeAndResetMetrics();
+        // Alarm is rescheduled.
+        verify(mAlarmManager).set(eq(AlarmManager.ELAPSED_REALTIME_WAKEUP),
+                eq(METRICS_COLLECTION_DURATION_MS * 2),
+                any() /* tag */, listenerCaptor.capture(), eq(mTestHandler));
+    }
+
     private void setupResponseWithSocketExisting() throws Exception {
         final ByteBuffer tcpBufferV6 = getByteBuffer(TEST_RESPONSE_BYTES);
         final ByteBuffer tcpBufferV4 = getByteBuffer(TEST_RESPONSE_BYTES);
@@ -772,41 +820,36 @@
 
         clearInvocations(mNai);
         // Start the second keepalive while the first is paused.
-        final TestKeepaliveInfo testInfo2 = doStartNattKeepalive();
-        // The slot used is TEST_SLOT since it is now a free slot.
-        checkAndProcessKeepaliveStart(TEST_SLOT, testInfo2.kpd);
-        verify(testInfo2.socketKeepaliveCallback).onStarted();
-        assertNotNull(getAutoKiForBinder(testInfo2.binder));
+        // TODO: Uncomment the following test after fixing b/283886067. Currently this attempts to
+        // start the keepalive on TEST_SLOT and this throws in the handler thread.
+        // final TestKeepaliveInfo testInfo2 = doStartNattKeepalive();
+        // // The slot used is TEST_SLOT + 1 since TEST_SLOT is being taken by the paused keepalive.
+        // checkAndProcessKeepaliveStart(TEST_SLOT + 1, testInfo2.kpd);
+        // verify(testInfo2.socketKeepaliveCallback).onStarted();
+        // assertNotNull(getAutoKiForBinder(testInfo2.binder));
 
-        clearInvocations(mNai);
-        doResumeKeepalive(autoKi1);
-        // The next free slot is TEST_SLOT + 1.
-        checkAndProcessKeepaliveStart(TEST_SLOT + 1, testInfo1.kpd);
-        verify(testInfo1.socketKeepaliveCallback).onResumed();
+        // clearInvocations(mNai);
+        // doResumeKeepalive(autoKi1);
+        // // Resume on TEST_SLOT.
+        // checkAndProcessKeepaliveStart(TEST_SLOT, testInfo1.kpd);
+        // verify(testInfo1.socketKeepaliveCallback).onResumed();
 
-        clearInvocations(mNai);
-        doStopKeepalive(autoKi1);
-        // TODO: The slot should be consistent with the checkAndProcessKeepaliveStart directly above
-        checkAndProcessKeepaliveStop(TEST_SLOT);
-        // TODO: onStopped should only be called on the first keepalive callback.
-        verify(testInfo1.socketKeepaliveCallback, never()).onStopped();
-        verify(testInfo2.socketKeepaliveCallback).onStopped();
-        assertNull(getAutoKiForBinder(testInfo1.binder));
+        // clearInvocations(mNai);
+        // doStopKeepalive(autoKi1);
+        // checkAndProcessKeepaliveStop(TEST_SLOT);
+        // verify(testInfo1.socketKeepaliveCallback).onStopped();
+        // verify(testInfo2.socketKeepaliveCallback, never()).onStopped();
+        // assertNull(getAutoKiForBinder(testInfo1.binder));
 
-        clearInvocations(mNai);
-        assertNotNull(getAutoKiForBinder(testInfo2.binder));
-        doStopKeepalive(getAutoKiForBinder(testInfo2.binder));
-        // This slot should be consistent with its corresponding checkAndProcessKeepaliveStart.
-        // TODO: checkAndProcessKeepaliveStop should be called instead but the keepalive is
-        // unexpectedly already stopped above.
-        verify(mNai, never()).onStopSocketKeepalive(TEST_SLOT);
-        verify(mNai, never()).onRemoveKeepalivePacketFilter(TEST_SLOT);
+        // clearInvocations(mNai);
+        // assertNotNull(getAutoKiForBinder(testInfo2.binder));
+        // doStopKeepalive(getAutoKiForBinder(testInfo2.binder));
+        // checkAndProcessKeepaliveStop(TEST_SLOT + 1);
+        // verify(testInfo2.socketKeepaliveCallback).onStopped();
+        // assertNull(getAutoKiForBinder(testInfo2.binder));
 
-        verify(testInfo2.socketKeepaliveCallback).onStopped();
-        assertNull(getAutoKiForBinder(testInfo2.binder));
-
-        verifyNoMoreInteractions(ignoreStubs(testInfo1.socketKeepaliveCallback));
-        verifyNoMoreInteractions(ignoreStubs(testInfo2.socketKeepaliveCallback));
+        // verifyNoMoreInteractions(ignoreStubs(testInfo1.socketKeepaliveCallback));
+        // verifyNoMoreInteractions(ignoreStubs(testInfo2.socketKeepaliveCallback));
     }
 
     @Test
diff --git a/tests/unit/java/com/android/server/connectivity/KeepaliveStatsTrackerTest.java b/tests/unit/java/com/android/server/connectivity/KeepaliveStatsTrackerTest.java
index 2e9bf26..0d2e540 100644
--- a/tests/unit/java/com/android/server/connectivity/KeepaliveStatsTrackerTest.java
+++ b/tests/unit/java/com/android/server/connectivity/KeepaliveStatsTrackerTest.java
@@ -16,112 +16,298 @@
 
 package com.android.server.connectivity;
 
+import static android.net.NetworkCapabilities.TRANSPORT_CELLULAR;
+import static android.net.NetworkCapabilities.TRANSPORT_WIFI;
+
 import static com.android.testutils.HandlerUtils.visibleOnHandlerThread;
 
+import static org.junit.Assert.assertArrayEquals;
 import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertFalse;
 import static org.junit.Assert.assertThrows;
-import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.anyLong;
+import static org.mockito.ArgumentMatchers.eq;
+import static org.mockito.Mockito.doCallRealMethod;
 import static org.mockito.Mockito.doReturn;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.verify;
 
+import android.content.BroadcastReceiver;
+import android.content.Context;
+import android.content.Intent;
+import android.net.Network;
+import android.net.NetworkCapabilities;
+import android.net.TelephonyNetworkSpecifier;
+import android.net.wifi.WifiInfo;
 import android.os.Build;
 import android.os.Handler;
 import android.os.HandlerThread;
+import android.telephony.SubscriptionInfo;
+import android.telephony.SubscriptionManager;
+import android.telephony.SubscriptionManager.OnSubscriptionsChangedListener;
+import android.telephony.TelephonyManager;
 
 import androidx.test.filters.SmallTest;
 
 import com.android.metrics.DailykeepaliveInfoReported;
 import com.android.metrics.DurationForNumOfKeepalive;
 import com.android.metrics.DurationPerNumOfKeepalive;
+import com.android.metrics.KeepaliveLifetimeForCarrier;
+import com.android.metrics.KeepaliveLifetimePerCarrier;
+import com.android.modules.utils.BackgroundThread;
+import com.android.net.module.util.CollectionUtils;
 import com.android.testutils.DevSdkIgnoreRule;
 import com.android.testutils.DevSdkIgnoreRunner;
+import com.android.testutils.HandlerUtils;
 
 import org.junit.Before;
 import org.junit.Test;
 import org.junit.runner.RunWith;
+import org.mockito.ArgumentCaptor;
 import org.mockito.Mock;
 import org.mockito.MockitoAnnotations;
 
+import java.util.Arrays;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
+
 @RunWith(DevSdkIgnoreRunner.class)
 @SmallTest
 @DevSdkIgnoreRule.IgnoreUpTo(Build.VERSION_CODES.TIRAMISU)
 public class KeepaliveStatsTrackerTest {
+    private static final int TIMEOUT_MS = 30_000;
+
+    private static final int TEST_SLOT = 1;
+    private static final int TEST_SLOT2 = 2;
+    private static final int TEST_KEEPALIVE_INTERVAL_SEC = 10;
+    private static final int TEST_KEEPALIVE_INTERVAL2_SEC = 20;
+    private static final int TEST_SUB_ID_1 = 1;
+    private static final int TEST_SUB_ID_2 = 2;
+    private static final int TEST_CARRIER_ID_1 = 135;
+    private static final int TEST_CARRIER_ID_2 = 246;
+    private static final Network TEST_NETWORK = new Network(123);
+    private static final NetworkCapabilities TEST_NETWORK_CAPABILITIES =
+            buildCellNetworkCapabilitiesWithSubId(TEST_SUB_ID_1);
+    private static final NetworkCapabilities TEST_NETWORK_CAPABILITIES_2 =
+            buildCellNetworkCapabilitiesWithSubId(TEST_SUB_ID_2);
+    private static final int TEST_UID = 1234;
+
+    private static NetworkCapabilities buildCellNetworkCapabilitiesWithSubId(int subId) {
+        final TelephonyNetworkSpecifier telephonyNetworkSpecifier =
+                new TelephonyNetworkSpecifier.Builder().setSubscriptionId(subId).build();
+        return new NetworkCapabilities.Builder()
+                .addTransportType(NetworkCapabilities.TRANSPORT_CELLULAR)
+                .setNetworkSpecifier(telephonyNetworkSpecifier)
+                .build();
+    }
+
     private HandlerThread mHandlerThread;
     private Handler mTestHandler;
 
     private KeepaliveStatsTracker mKeepaliveStatsTracker;
 
+    @Mock private Context mContext;
     @Mock private KeepaliveStatsTracker.Dependencies mDependencies;
+    @Mock private SubscriptionManager mSubscriptionManager;
+
+    private void triggerBroadcastDefaultSubId(int subId) {
+        final ArgumentCaptor<BroadcastReceiver> receiverCaptor =
+                ArgumentCaptor.forClass(BroadcastReceiver.class);
+        verify(mContext).registerReceiver(receiverCaptor.capture(), /* filter= */ any(),
+                /* broadcastPermission= */ any(), eq(mTestHandler));
+        final Intent intent =
+                new Intent(TelephonyManager.ACTION_SUBSCRIPTION_CARRIER_IDENTITY_CHANGED);
+        intent.putExtra(SubscriptionManager.EXTRA_SUBSCRIPTION_INDEX, subId);
+
+        receiverCaptor.getValue().onReceive(mContext, intent);
+    }
+
+    private OnSubscriptionsChangedListener getOnSubscriptionsChangedListener() {
+        final ArgumentCaptor<OnSubscriptionsChangedListener> listenerCaptor =
+                ArgumentCaptor.forClass(OnSubscriptionsChangedListener.class);
+        verify(mSubscriptionManager)
+                .addOnSubscriptionsChangedListener(any(), listenerCaptor.capture());
+        return listenerCaptor.getValue();
+    }
+
+    private static final class KeepaliveCarrierStats {
+        public final int carrierId;
+        public final int transportTypes;
+        public final int intervalMs;
+        public final int lifetimeMs;
+        public final int activeLifetimeMs;
+
+        KeepaliveCarrierStats(
+                int carrierId,
+                int transportTypes,
+                int intervalMs,
+                int lifetimeMs,
+                int activeLifetimeMs) {
+            this.carrierId = carrierId;
+            this.transportTypes = transportTypes;
+            this.intervalMs = intervalMs;
+            this.lifetimeMs = lifetimeMs;
+            this.activeLifetimeMs = activeLifetimeMs;
+        }
+
+        // Equals method on only the key, (carrierId, tranportTypes, intervalMs)
+        @Override
+        public boolean equals(Object o) {
+            if (this == o) return true;
+            if (o == null || getClass() != o.getClass()) return false;
+
+            final KeepaliveCarrierStats that = (KeepaliveCarrierStats) o;
+
+            return carrierId == that.carrierId && transportTypes == that.transportTypes
+                    && intervalMs == that.intervalMs;
+        }
+
+        @Override
+        public int hashCode() {
+            return carrierId + 3 * transportTypes + 5 * intervalMs;
+        }
+    }
+
+    // Use the default test carrier id, transportType and keepalive interval.
+    private KeepaliveCarrierStats getDefaultCarrierStats(int lifetimeMs, int activeLifetimeMs) {
+        return new KeepaliveCarrierStats(
+                TEST_CARRIER_ID_1,
+                /* transportTypes= */ (1 << TRANSPORT_CELLULAR),
+                TEST_KEEPALIVE_INTERVAL_SEC * 1000,
+                lifetimeMs,
+                activeLifetimeMs);
+    }
+
+    private <T> void mockService(String serviceName, Class<T> serviceClass, T service) {
+        doReturn(serviceName).when(mContext).getSystemServiceName(serviceClass);
+        doReturn(service).when(mContext).getSystemService(serviceName);
+        if (mContext.getSystemService(serviceClass) == null) {
+            // Test is using mockito-extended
+            doCallRealMethod().when(mContext).getSystemService(serviceClass);
+        }
+    }
+
+    private SubscriptionInfo makeSubInfoMock(int subId, int carrierId) {
+        final SubscriptionInfo subInfo = mock(SubscriptionInfo.class);
+        doReturn(subId).when(subInfo).getSubscriptionId();
+        doReturn(carrierId).when(subInfo).getCarrierId();
+        return subInfo;
+    }
 
     @Before
     public void setUp() {
         MockitoAnnotations.initMocks(this);
+        mockService(Context.TELEPHONY_SUBSCRIPTION_SERVICE, SubscriptionManager.class,
+                mSubscriptionManager);
+
+        final SubscriptionInfo subInfo1 = makeSubInfoMock(TEST_SUB_ID_1, TEST_CARRIER_ID_1);
+        final SubscriptionInfo subInfo2 = makeSubInfoMock(TEST_SUB_ID_2, TEST_CARRIER_ID_2);
+
+        doReturn(List.of(subInfo1, subInfo2))
+                .when(mSubscriptionManager)
+                .getActiveSubscriptionInfoList();
 
         mHandlerThread = new HandlerThread("KeepaliveStatsTrackerTest");
         mHandlerThread.start();
         mTestHandler = new Handler(mHandlerThread.getLooper());
 
-        setUptimeMillis(0);
-        mKeepaliveStatsTracker = new KeepaliveStatsTracker(mTestHandler, mDependencies);
+        setElapsedRealtime(0);
+        mKeepaliveStatsTracker = new KeepaliveStatsTracker(mContext, mTestHandler, mDependencies);
+        HandlerUtils.waitForIdle(BackgroundThread.getHandler(), TIMEOUT_MS);
+
+        // Initial onSubscriptionsChanged.
+        getOnSubscriptionsChangedListener().onSubscriptionsChanged();
+        HandlerUtils.waitForIdle(mTestHandler, TIMEOUT_MS);
     }
 
-    private void setUptimeMillis(long time) {
-        doReturn(time).when(mDependencies).getUptimeMillis();
+    private void setElapsedRealtime(long time) {
+        doReturn(time).when(mDependencies).getElapsedRealtime();
     }
 
     private DailykeepaliveInfoReported buildKeepaliveMetrics(long time) {
-        setUptimeMillis(time);
+        setElapsedRealtime(time);
 
         return visibleOnHandlerThread(
                 mTestHandler, () -> mKeepaliveStatsTracker.buildKeepaliveMetrics());
     }
 
     private DailykeepaliveInfoReported buildAndResetMetrics(long time) {
-        setUptimeMillis(time);
+        setElapsedRealtime(time);
 
         return visibleOnHandlerThread(
-                mTestHandler,
-                () -> {
-                    final DailykeepaliveInfoReported dailyKeepaliveInfoReported =
-                            mKeepaliveStatsTracker.buildKeepaliveMetrics();
-                    mKeepaliveStatsTracker.resetMetrics();
-                    return dailyKeepaliveInfoReported;
-                });
+                mTestHandler, () -> mKeepaliveStatsTracker.buildAndResetMetrics());
     }
 
-    private void onStartKeepalive(long time) {
-        setUptimeMillis(time);
-        visibleOnHandlerThread(mTestHandler, () -> mKeepaliveStatsTracker.onStartKeepalive());
+    private void onStartKeepalive(long time, int slot) {
+        onStartKeepalive(time, slot, TEST_KEEPALIVE_INTERVAL_SEC);
     }
 
-    private void onPauseKeepalive(long time) {
-        setUptimeMillis(time);
-        visibleOnHandlerThread(mTestHandler, () -> mKeepaliveStatsTracker.onPauseKeepalive());
+    private void onStartKeepalive(long time, int slot, int intervalSeconds) {
+        onStartKeepalive(time, slot, TEST_NETWORK_CAPABILITIES, intervalSeconds);
     }
 
-    private void onResumeKeepalive(long time) {
-        setUptimeMillis(time);
-        visibleOnHandlerThread(mTestHandler, () -> mKeepaliveStatsTracker.onResumeKeepalive());
+    private void onStartKeepalive(long time, int slot, NetworkCapabilities nc) {
+        onStartKeepalive(time, slot, nc, TEST_KEEPALIVE_INTERVAL_SEC);
     }
 
-    private void onStopKeepalive(long time, boolean wasActive) {
-        setUptimeMillis(time);
+    private void onStartKeepalive(
+            long time, int slot, NetworkCapabilities nc, int intervalSeconds) {
+        onStartKeepalive(time, slot, nc, intervalSeconds, TEST_UID, /* isAutoKeepalive= */ true);
+    }
+
+    private void onStartKeepalive(long time, int slot, NetworkCapabilities nc, int intervalSeconds,
+            int uid, boolean isAutoKeepalive) {
+        setElapsedRealtime(time);
+        visibleOnHandlerThread(mTestHandler, () ->
+                mKeepaliveStatsTracker.onStartKeepalive(TEST_NETWORK, slot, nc, intervalSeconds,
+                        uid, isAutoKeepalive));
+    }
+
+    private void onPauseKeepalive(long time, int slot) {
+        setElapsedRealtime(time);
         visibleOnHandlerThread(
-                mTestHandler, () -> mKeepaliveStatsTracker.onStopKeepalive(wasActive));
+                mTestHandler, () -> mKeepaliveStatsTracker.onPauseKeepalive(TEST_NETWORK, slot));
+    }
+
+    private void onResumeKeepalive(long time, int slot) {
+        setElapsedRealtime(time);
+        visibleOnHandlerThread(
+                mTestHandler, () -> mKeepaliveStatsTracker.onResumeKeepalive(TEST_NETWORK, slot));
+    }
+
+    private void onStopKeepalive(long time, int slot) {
+        setElapsedRealtime(time);
+        visibleOnHandlerThread(
+                mTestHandler, () -> mKeepaliveStatsTracker.onStopKeepalive(TEST_NETWORK, slot));
     }
 
     @Test
     public void testEnsureRunningOnHandlerThread() {
         // Not running on handler thread
-        assertThrows(IllegalStateException.class, () -> mKeepaliveStatsTracker.onStartKeepalive());
-        assertThrows(IllegalStateException.class, () -> mKeepaliveStatsTracker.onPauseKeepalive());
-        assertThrows(IllegalStateException.class, () -> mKeepaliveStatsTracker.onResumeKeepalive());
         assertThrows(
-                IllegalStateException.class, () -> mKeepaliveStatsTracker.onStopKeepalive(true));
+                IllegalStateException.class,
+                () -> mKeepaliveStatsTracker.onStartKeepalive(
+                        TEST_NETWORK,
+                        TEST_SLOT,
+                        TEST_NETWORK_CAPABILITIES,
+                        TEST_KEEPALIVE_INTERVAL_SEC,
+                        TEST_UID,
+                        /* isAutoKeepalive */ true));
+        assertThrows(
+                IllegalStateException.class,
+                () -> mKeepaliveStatsTracker.onPauseKeepalive(TEST_NETWORK, TEST_SLOT));
+        assertThrows(
+                IllegalStateException.class,
+                () -> mKeepaliveStatsTracker.onResumeKeepalive(TEST_NETWORK, TEST_SLOT));
+        assertThrows(
+                IllegalStateException.class,
+                () -> mKeepaliveStatsTracker.onStopKeepalive(TEST_NETWORK, TEST_SLOT));
         assertThrows(
                 IllegalStateException.class, () -> mKeepaliveStatsTracker.buildKeepaliveMetrics());
         assertThrows(
-                IllegalStateException.class, () -> mKeepaliveStatsTracker.resetMetrics());
+                IllegalStateException.class, () -> mKeepaliveStatsTracker.buildAndResetMetrics());
     }
 
     /**
@@ -133,45 +319,112 @@
      * @param expectActiveDurations integer array where the index is the number of concurrent
      *     keepalives and the value is the expected duration of time that the tracker is in a state
      *     with the given number of keepalives active.
-     * @param resultDurationsPerNumOfKeepalive the DurationPerNumOfKeepalive message to assert.
+     * @param actualDurationsPerNumOfKeepalive the DurationPerNumOfKeepalive message to assert.
      */
     private void assertDurationMetrics(
             int[] expectRegisteredDurations,
             int[] expectActiveDurations,
-            DurationPerNumOfKeepalive resultDurationsPerNumOfKeepalive) {
+            DurationPerNumOfKeepalive actualDurationsPerNumOfKeepalive) {
         final int maxNumOfKeepalive = expectRegisteredDurations.length;
         assertEquals(maxNumOfKeepalive, expectActiveDurations.length);
         assertEquals(
                 maxNumOfKeepalive,
-                resultDurationsPerNumOfKeepalive.getDurationForNumOfKeepaliveCount());
+                actualDurationsPerNumOfKeepalive.getDurationForNumOfKeepaliveCount());
         for (int numOfKeepalive = 0; numOfKeepalive < maxNumOfKeepalive; numOfKeepalive++) {
-            final DurationForNumOfKeepalive resultDurations =
-                    resultDurationsPerNumOfKeepalive.getDurationForNumOfKeepalive(numOfKeepalive);
+            final DurationForNumOfKeepalive actualDurations =
+                    actualDurationsPerNumOfKeepalive.getDurationForNumOfKeepalive(numOfKeepalive);
 
-            assertEquals(numOfKeepalive, resultDurations.getNumOfKeepalive());
+            assertEquals(numOfKeepalive, actualDurations.getNumOfKeepalive());
             assertEquals(
                     expectRegisteredDurations[numOfKeepalive],
-                    resultDurations.getKeepaliveRegisteredDurationsMsec());
+                    actualDurations.getKeepaliveRegisteredDurationsMsec());
             assertEquals(
                     expectActiveDurations[numOfKeepalive],
-                    resultDurations.getKeepaliveActiveDurationsMsec());
+                    actualDurations.getKeepaliveActiveDurationsMsec());
+        }
+    }
+
+    /**
+     * Asserts the actual KeepaliveLifetimePerCarrier contains an expected KeepaliveCarrierStats.
+     * This finds and checks only for the (carrierId, transportTypes, intervalMs) of the given
+     * expectKeepaliveCarrierStats and asserts the lifetime metrics.
+     *
+     * @param expectKeepaliveCarrierStats a keepalive lifetime metric that is expected to be in the
+     *     proto.
+     * @param actualKeepaliveLifetimePerCarrier the KeepaliveLifetimePerCarrier message to assert.
+     */
+    private void findAndAssertCarrierLifetimeMetrics(
+            KeepaliveCarrierStats expectKeepaliveCarrierStats,
+            KeepaliveLifetimePerCarrier actualKeepaliveLifetimePerCarrier) {
+        for (KeepaliveLifetimeForCarrier keepaliveLifetimeForCarrier :
+                actualKeepaliveLifetimePerCarrier.getKeepaliveLifetimeForCarrierList()) {
+            if (expectKeepaliveCarrierStats.carrierId == keepaliveLifetimeForCarrier.getCarrierId()
+                    && expectKeepaliveCarrierStats.transportTypes
+                            == keepaliveLifetimeForCarrier.getTransportTypes()
+                    && expectKeepaliveCarrierStats.intervalMs
+                            == keepaliveLifetimeForCarrier.getIntervalsMsec()) {
+                assertEquals(
+                        expectKeepaliveCarrierStats.lifetimeMs,
+                        keepaliveLifetimeForCarrier.getLifetimeMsec());
+                assertEquals(
+                        expectKeepaliveCarrierStats.activeLifetimeMs,
+                        keepaliveLifetimeForCarrier.getActiveLifetimeMsec());
+                return;
+            }
+        }
+        fail("KeepaliveLifetimeForCarrier not found for a given expected KeepaliveCarrierStats");
+    }
+
+    private void assertNoDuplicates(Object[] arr) {
+        final Set<Object> s = new HashSet<Object>(Arrays.asList(arr));
+        assertEquals(arr.length, s.size());
+    }
+
+    /**
+     * Asserts that a KeepaliveLifetimePerCarrier contains all the expected KeepaliveCarrierStats.
+     *
+     * @param expectKeepaliveCarrierStatsArray an array of keepalive lifetime metrics that is
+     *     expected to be in the KeepaliveLifetimePerCarrier.
+     * @param actualKeepaliveLifetimePerCarrier the KeepaliveLifetimePerCarrier message to assert.
+     */
+    private void assertCarrierLifetimeMetrics(
+            KeepaliveCarrierStats[] expectKeepaliveCarrierStatsArray,
+            KeepaliveLifetimePerCarrier actualKeepaliveLifetimePerCarrier) {
+        assertNoDuplicates(expectKeepaliveCarrierStatsArray);
+        assertEquals(
+                expectKeepaliveCarrierStatsArray.length,
+                actualKeepaliveLifetimePerCarrier.getKeepaliveLifetimeForCarrierCount());
+        for (KeepaliveCarrierStats keepaliveCarrierStats : expectKeepaliveCarrierStatsArray) {
+            findAndAssertCarrierLifetimeMetrics(
+                    keepaliveCarrierStats, actualKeepaliveLifetimePerCarrier);
         }
     }
 
     private void assertDailyKeepaliveInfoReported(
             DailykeepaliveInfoReported dailyKeepaliveInfoReported,
+            int expectRequestsCount,
+            int expectAutoRequestsCount,
+            int[] expectAppUids,
             int[] expectRegisteredDurations,
-            int[] expectActiveDurations) {
-        // TODO(b/273451360) Assert these values when they are filled.
-        assertFalse(dailyKeepaliveInfoReported.hasKeepaliveLifetimePerCarrier());
-        assertFalse(dailyKeepaliveInfoReported.hasKeepaliveRequests());
-        assertFalse(dailyKeepaliveInfoReported.hasAutomaticKeepaliveRequests());
-        assertFalse(dailyKeepaliveInfoReported.hasDistinctUserCount());
-        assertTrue(dailyKeepaliveInfoReported.getUidList().isEmpty());
+            int[] expectActiveDurations,
+            KeepaliveCarrierStats[] expectKeepaliveCarrierStatsArray) {
+        assertEquals(expectRequestsCount, dailyKeepaliveInfoReported.getKeepaliveRequests());
+        assertEquals(
+                expectAutoRequestsCount,
+                dailyKeepaliveInfoReported.getAutomaticKeepaliveRequests());
+        assertEquals(expectAppUids.length, dailyKeepaliveInfoReported.getDistinctUserCount());
 
-        final DurationPerNumOfKeepalive resultDurations =
+        final int[] uidArray = CollectionUtils.toIntArray(dailyKeepaliveInfoReported.getUidList());
+        assertArrayEquals(expectAppUids, uidArray);
+
+        final DurationPerNumOfKeepalive actualDurations =
                 dailyKeepaliveInfoReported.getDurationPerNumOfKeepalive();
-        assertDurationMetrics(expectRegisteredDurations, expectActiveDurations, resultDurations);
+        assertDurationMetrics(expectRegisteredDurations, expectActiveDurations, actualDurations);
+
+        final KeepaliveLifetimePerCarrier actualCarrierLifetime =
+                dailyKeepaliveInfoReported.getKeepaliveLifetimePerCarrier();
+
+        assertCarrierLifetimeMetrics(expectKeepaliveCarrierStatsArray, actualCarrierLifetime);
     }
 
     @Test
@@ -187,8 +440,12 @@
 
         assertDailyKeepaliveInfoReported(
                 dailyKeepaliveInfoReported,
+                /* expectRequestsCount= */ 0,
+                /* expectAutoRequestsCount= */ 0,
+                /* expectAppUids= */ new int[0],
                 expectRegisteredDurations,
-                expectActiveDurations);
+                expectActiveDurations,
+                new KeepaliveCarrierStats[0]);
     }
 
     /*
@@ -203,7 +460,7 @@
         final int startTime = 1000;
         final int writeTime = 5000;
 
-        onStartKeepalive(startTime);
+        onStartKeepalive(startTime, TEST_SLOT);
 
         final DailykeepaliveInfoReported dailyKeepaliveInfoReported =
                 buildKeepaliveMetrics(writeTime);
@@ -214,8 +471,14 @@
         final int[] expectActiveDurations = new int[] {startTime, writeTime - startTime};
         assertDailyKeepaliveInfoReported(
                 dailyKeepaliveInfoReported,
+                /* expectRequestsCount= */ 1,
+                /* expectAutoRequestsCount= */ 1,
+                /* expectAppUids= */ new int[] {TEST_UID},
                 expectRegisteredDurations,
-                expectActiveDurations);
+                expectActiveDurations,
+                new KeepaliveCarrierStats[] {
+                    getDefaultCarrierStats(expectRegisteredDurations[1], expectActiveDurations[1])
+                });
     }
 
     /*
@@ -231,9 +494,9 @@
         final int pauseTime = 2030;
         final int writeTime = 5000;
 
-        onStartKeepalive(startTime);
+        onStartKeepalive(startTime, TEST_SLOT);
 
-        onPauseKeepalive(pauseTime);
+        onPauseKeepalive(pauseTime, TEST_SLOT);
 
         final DailykeepaliveInfoReported dailyKeepaliveInfoReported =
                 buildKeepaliveMetrics(writeTime);
@@ -246,8 +509,14 @@
                 new int[] {startTime + (writeTime - pauseTime), pauseTime - startTime};
         assertDailyKeepaliveInfoReported(
                 dailyKeepaliveInfoReported,
+                /* expectRequestsCount= */ 1,
+                /* expectAutoRequestsCount= */ 1,
+                /* expectAppUids= */ new int[] {TEST_UID},
                 expectRegisteredDurations,
-                expectActiveDurations);
+                expectActiveDurations,
+                new KeepaliveCarrierStats[] {
+                    getDefaultCarrierStats(expectRegisteredDurations[1], expectActiveDurations[1])
+                });
     }
 
     /*
@@ -264,11 +533,11 @@
         final int resumeTime = 3450;
         final int writeTime = 5000;
 
-        onStartKeepalive(startTime);
+        onStartKeepalive(startTime, TEST_SLOT);
 
-        onPauseKeepalive(pauseTime);
+        onPauseKeepalive(pauseTime, TEST_SLOT);
 
-        onResumeKeepalive(resumeTime);
+        onResumeKeepalive(resumeTime, TEST_SLOT);
 
         final DailykeepaliveInfoReported dailyKeepaliveInfoReported =
                 buildKeepaliveMetrics(writeTime);
@@ -284,8 +553,14 @@
                 };
         assertDailyKeepaliveInfoReported(
                 dailyKeepaliveInfoReported,
+                /* expectRequestsCount= */ 1,
+                /* expectAutoRequestsCount= */ 1,
+                /* expectAppUids= */ new int[] {TEST_UID},
                 expectRegisteredDurations,
-                expectActiveDurations);
+                expectActiveDurations,
+                new KeepaliveCarrierStats[] {
+                    getDefaultCarrierStats(expectRegisteredDurations[1], expectActiveDurations[1])
+                });
     }
 
     /*
@@ -303,13 +578,13 @@
         final int stopTime = 4157;
         final int writeTime = 5000;
 
-        onStartKeepalive(startTime);
+        onStartKeepalive(startTime, TEST_SLOT);
 
-        onPauseKeepalive(pauseTime);
+        onPauseKeepalive(pauseTime, TEST_SLOT);
 
-        onResumeKeepalive(resumeTime);
+        onResumeKeepalive(resumeTime, TEST_SLOT);
 
-        onStopKeepalive(stopTime, /* wasActive= */ true);
+        onStopKeepalive(stopTime, TEST_SLOT);
 
         final DailykeepaliveInfoReported dailyKeepaliveInfoReported =
                 buildKeepaliveMetrics(writeTime);
@@ -326,8 +601,14 @@
                 };
         assertDailyKeepaliveInfoReported(
                 dailyKeepaliveInfoReported,
+                /* expectRequestsCount= */ 1,
+                /* expectAutoRequestsCount= */ 1,
+                /* expectAppUids= */ new int[] {TEST_UID},
                 expectRegisteredDurations,
-                expectActiveDurations);
+                expectActiveDurations,
+                new KeepaliveCarrierStats[] {
+                    getDefaultCarrierStats(expectRegisteredDurations[1], expectActiveDurations[1])
+                });
     }
 
     /*
@@ -344,11 +625,11 @@
         final int stopTime = 4157;
         final int writeTime = 5000;
 
-        onStartKeepalive(startTime);
+        onStartKeepalive(startTime, TEST_SLOT);
 
-        onPauseKeepalive(pauseTime);
+        onPauseKeepalive(pauseTime, TEST_SLOT);
 
-        onStopKeepalive(stopTime, /* wasActive= */ false);
+        onStopKeepalive(stopTime, TEST_SLOT);
 
         final DailykeepaliveInfoReported dailyKeepaliveInfoReported =
                 buildKeepaliveMetrics(writeTime);
@@ -362,8 +643,14 @@
                 new int[] {startTime + (writeTime - pauseTime), (pauseTime - startTime)};
         assertDailyKeepaliveInfoReported(
                 dailyKeepaliveInfoReported,
+                /* expectRequestsCount= */ 1,
+                /* expectAutoRequestsCount= */ 1,
+                /* expectAppUids= */ new int[] {TEST_UID},
                 expectRegisteredDurations,
-                expectActiveDurations);
+                expectActiveDurations,
+                new KeepaliveCarrierStats[] {
+                    getDefaultCarrierStats(expectRegisteredDurations[1], expectActiveDurations[1])
+                });
     }
 
     /*
@@ -381,17 +668,17 @@
         final int stopTime = 4000;
         final int writeTime = 5000;
 
-        onStartKeepalive(startTime);
+        onStartKeepalive(startTime, TEST_SLOT);
 
         for (int i = 0; i < pauseResumeTimes.length; i++) {
             if (i % 2 == 0) {
-                onPauseKeepalive(pauseResumeTimes[i]);
+                onPauseKeepalive(pauseResumeTimes[i], TEST_SLOT);
             } else {
-                onResumeKeepalive(pauseResumeTimes[i]);
+                onResumeKeepalive(pauseResumeTimes[i], TEST_SLOT);
             }
         }
 
-        onStopKeepalive(stopTime, /* wasActive= */ true);
+        onStopKeepalive(stopTime, TEST_SLOT);
 
         final DailykeepaliveInfoReported dailyKeepaliveInfoReported =
                 buildKeepaliveMetrics(writeTime);
@@ -407,8 +694,14 @@
                 };
         assertDailyKeepaliveInfoReported(
                 dailyKeepaliveInfoReported,
+                /* expectRequestsCount= */ 1,
+                /* expectAutoRequestsCount= */ 1,
+                /* expectAppUids= */ new int[] {TEST_UID},
                 expectRegisteredDurations,
-                expectActiveDurations);
+                expectActiveDurations,
+                new KeepaliveCarrierStats[] {
+                    getDefaultCarrierStats(expectRegisteredDurations[1], expectActiveDurations[1])
+                });
     }
 
     /*
@@ -431,19 +724,19 @@
         final int stopTime1 = 4157;
         final int writeTime = 5000;
 
-        onStartKeepalive(startTime1);
+        onStartKeepalive(startTime1, TEST_SLOT);
 
-        onPauseKeepalive(pauseTime1);
+        onPauseKeepalive(pauseTime1, TEST_SLOT);
 
-        onStartKeepalive(startTime2);
+        onStartKeepalive(startTime2, TEST_SLOT2);
 
-        onResumeKeepalive(resumeTime1);
+        onResumeKeepalive(resumeTime1, TEST_SLOT);
 
-        onPauseKeepalive(pauseTime2);
+        onPauseKeepalive(pauseTime2, TEST_SLOT2);
 
-        onResumeKeepalive(resumeTime2);
+        onResumeKeepalive(resumeTime2, TEST_SLOT2);
 
-        onStopKeepalive(stopTime1, /* wasActive= */ true);
+        onStopKeepalive(stopTime1, TEST_SLOT);
 
         final DailykeepaliveInfoReported dailyKeepaliveInfoReported =
                 buildKeepaliveMetrics(writeTime);
@@ -474,10 +767,21 @@
                     // 2 active keepalives before keepalive2 is paused and before keepalive1 stops.
                     (pauseTime2 - resumeTime1) + (stopTime1 - resumeTime2)
                 };
+
         assertDailyKeepaliveInfoReported(
                 dailyKeepaliveInfoReported,
+                /* expectRequestsCount= */ 2,
+                /* expectAutoRequestsCount= */ 2,
+                /* expectAppUids= */ new int[] {TEST_UID},
                 expectRegisteredDurations,
-                expectActiveDurations);
+                expectActiveDurations,
+                // The carrier stats are aggregated here since the keepalives have the same
+                // (carrierId, transportTypes, intervalMs).
+                new KeepaliveCarrierStats[] {
+                    getDefaultCarrierStats(
+                            expectRegisteredDurations[1] + 2 * expectRegisteredDurations[2],
+                            expectActiveDurations[1] + 2 * expectActiveDurations[2])
+                });
     }
 
     /*
@@ -494,7 +798,7 @@
         final int stopTime = 7000;
         final int writeTime2 = 10000;
 
-        onStartKeepalive(startTime);
+        onStartKeepalive(startTime, TEST_SLOT);
 
         final DailykeepaliveInfoReported dailyKeepaliveInfoReported =
                 buildAndResetMetrics(writeTime);
@@ -504,20 +808,31 @@
         final int[] expectActiveDurations = new int[] {startTime, writeTime - startTime};
         assertDailyKeepaliveInfoReported(
                 dailyKeepaliveInfoReported,
+                /* expectRequestsCount= */ 1,
+                /* expectAutoRequestsCount= */ 1,
+                /* expectAppUids= */ new int[] {TEST_UID},
                 expectRegisteredDurations,
-                expectActiveDurations);
+                expectActiveDurations,
+                new KeepaliveCarrierStats[] {
+                    getDefaultCarrierStats(expectRegisteredDurations[1], expectActiveDurations[1])
+                });
 
+        // Check metrics was reset from above.
         final DailykeepaliveInfoReported dailyKeepaliveInfoReported2 =
                 buildKeepaliveMetrics(writeTime);
 
         // Expect the stored durations to be 0 but still contain the number of keepalive = 1.
         assertDailyKeepaliveInfoReported(
                 dailyKeepaliveInfoReported2,
+                /* expectRequestsCount= */ 1,
+                /* expectAutoRequestsCount= */ 1,
+                /* expectAppUids= */ new int[] {TEST_UID},
                 /* expectRegisteredDurations= */ new int[] {0, 0},
-                /* expectActiveDurations= */ new int[] {0, 0});
+                /* expectActiveDurations= */ new int[] {0, 0},
+                new KeepaliveCarrierStats[] {getDefaultCarrierStats(0, 0)});
 
         // Expect that the keepalive is still registered after resetting so it can be stopped.
-        onStopKeepalive(stopTime, /* wasActive= */ true);
+        onStopKeepalive(stopTime, TEST_SLOT);
 
         final DailykeepaliveInfoReported dailyKeepaliveInfoReported3 =
                 buildKeepaliveMetrics(writeTime2);
@@ -528,7 +843,353 @@
                 new int[] {writeTime2 - stopTime, stopTime - writeTime};
         assertDailyKeepaliveInfoReported(
                 dailyKeepaliveInfoReported3,
+                /* expectRequestsCount= */ 1,
+                /* expectAutoRequestsCount= */ 1,
+                /* expectAppUids= */ new int[] {TEST_UID},
                 expectRegisteredDurations2,
-                expectActiveDurations2);
+                expectActiveDurations2,
+                new KeepaliveCarrierStats[] {
+                    getDefaultCarrierStats(expectRegisteredDurations2[1], expectActiveDurations2[1])
+                });
+    }
+
+    /*
+     * Diagram of test (not to scale):
+     * Key: S - Start/Stop, P - Pause, R - Resume, W - Write
+     *
+     * Keepalive1     S1      S1  W+reset         W
+     * Keepalive2         S2      W+reset         W
+     * Timeline    |------------------------------|
+     */
+    @Test
+    public void testResetMetrics_twoKeepalives() {
+        final int startTime1 = 1000;
+        final int startTime2 = 2000;
+        final int stopTime1 = 4157;
+        final int writeTime = 5000;
+        final int writeTime2 = 10000;
+
+        onStartKeepalive(startTime1, TEST_SLOT);
+
+        onStartKeepalive(startTime2, TEST_SLOT2, TEST_NETWORK_CAPABILITIES_2,
+                TEST_KEEPALIVE_INTERVAL2_SEC);
+
+        onStopKeepalive(stopTime1, TEST_SLOT);
+
+        final DailykeepaliveInfoReported dailyKeepaliveInfoReported =
+                buildAndResetMetrics(writeTime);
+
+        final int[] expectRegisteredDurations =
+                new int[] {
+                    startTime1,
+                    // 1 keepalive before keepalive2 starts and after keepalive1 stops.
+                    (startTime2 - startTime1) + (writeTime - stopTime1),
+                    stopTime1 - startTime2
+                };
+        // Since there is no pause, expect the same as registered durations.
+        final int[] expectActiveDurations =
+                new int[] {
+                    startTime1,
+                    (startTime2 - startTime1) + (writeTime - stopTime1),
+                    stopTime1 - startTime2
+                };
+
+        // Lifetime carrier stats are independent of each other since they have different intervals.
+        final KeepaliveCarrierStats expectKeepaliveCarrierStats1 =
+                getDefaultCarrierStats(stopTime1 - startTime1, stopTime1 - startTime1);
+        final KeepaliveCarrierStats expectKeepaliveCarrierStats2 =
+                new KeepaliveCarrierStats(
+                        TEST_CARRIER_ID_2,
+                        /* transportTypes= */ (1 << TRANSPORT_CELLULAR),
+                        TEST_KEEPALIVE_INTERVAL2_SEC * 1000,
+                        writeTime - startTime2,
+                        writeTime - startTime2);
+
+        assertDailyKeepaliveInfoReported(
+                dailyKeepaliveInfoReported,
+                /* expectRequestsCount= */ 2,
+                /* expectAutoRequestsCount= */ 2,
+                /* expectAppUids= */ new int[] {TEST_UID},
+                expectRegisteredDurations,
+                expectActiveDurations,
+                new KeepaliveCarrierStats[] {
+                    expectKeepaliveCarrierStats1, expectKeepaliveCarrierStats2
+                });
+
+        final DailykeepaliveInfoReported dailyKeepaliveInfoReported2 =
+                buildKeepaliveMetrics(writeTime2);
+
+        // Only 1 keepalive is registered and active since the reset until the writeTime2.
+        final int[] expectRegisteredDurations2 = new int[] {0, writeTime2 - writeTime};
+        final int[] expectActiveDurations2 = new int[] {0, writeTime2 - writeTime};
+
+        // Only the keepalive with interval of intervalSec2 is present.
+        final KeepaliveCarrierStats expectKeepaliveCarrierStats3 =
+                new KeepaliveCarrierStats(
+                        TEST_CARRIER_ID_2,
+                        /* transportTypes= */ (1 << TRANSPORT_CELLULAR),
+                        TEST_KEEPALIVE_INTERVAL2_SEC * 1000,
+                        writeTime2 - writeTime,
+                        writeTime2 - writeTime);
+
+        assertDailyKeepaliveInfoReported(
+                dailyKeepaliveInfoReported2,
+                /* expectRequestsCount= */ 1,
+                /* expectAutoRequestsCount= */ 1,
+                /* expectAppUids= */ new int[] {TEST_UID},
+                expectRegisteredDurations2,
+                expectActiveDurations2,
+                new KeepaliveCarrierStats[] {expectKeepaliveCarrierStats3});
+    }
+
+    @Test
+    public void testReusableSlot_keepaliveNotStopped() {
+        final int startTime1 = 1000;
+        final int startTime2 = 2000;
+        final int writeTime = 5000;
+
+        onStartKeepalive(startTime1, TEST_SLOT);
+
+        // Attempt to use the same (network, slot)
+        assertThrows(IllegalArgumentException.class, () -> onStartKeepalive(startTime2, TEST_SLOT));
+
+        final DailykeepaliveInfoReported dailyKeepaliveInfoReported =
+                buildKeepaliveMetrics(writeTime);
+
+        // Expect the duration to be from startTime1 and not startTime2, it should not start again.
+        final int[] expectRegisteredDurations = new int[] {startTime1, writeTime - startTime1};
+        final int[] expectActiveDurations = new int[] {startTime1, writeTime - startTime1};
+
+        assertDailyKeepaliveInfoReported(
+                dailyKeepaliveInfoReported,
+                /* expectRequestsCount= */ 1,
+                /* expectAutoRequestsCount= */ 1,
+                /* expectAppUids= */ new int[] {TEST_UID},
+                expectRegisteredDurations,
+                expectActiveDurations,
+                new KeepaliveCarrierStats[] {
+                    getDefaultCarrierStats(expectRegisteredDurations[1], expectActiveDurations[1])
+                });
+    }
+
+    @Test
+    public void testReusableSlot_keepaliveStopped() {
+        final int startTime1 = 1000;
+        final int stopTime = 2000;
+        final int startTime2 = 3000;
+        final int writeTime = 5000;
+
+        onStartKeepalive(startTime1, TEST_SLOT);
+
+        onStopKeepalive(stopTime, TEST_SLOT);
+
+        // Attempt to use the same (network, slot)
+        onStartKeepalive(startTime2, TEST_SLOT);
+
+        final DailykeepaliveInfoReported dailyKeepaliveInfoReported =
+                buildKeepaliveMetrics(writeTime);
+
+        // Expect the durations to be an aggregate of both periods.
+        // i.e. onStartKeepalive works on the same (network, slot) if it has been stopped.
+        final int[] expectRegisteredDurations =
+                new int[] {
+                    startTime1 + (startTime2 - stopTime),
+                    (stopTime - startTime1) + (writeTime - startTime2)
+                };
+        final int[] expectActiveDurations =
+                new int[] {
+                    startTime1 + (startTime2 - stopTime),
+                    (stopTime - startTime1) + (writeTime - startTime2)
+                };
+
+        assertDailyKeepaliveInfoReported(
+                dailyKeepaliveInfoReported,
+                /* expectRequestsCount= */ 2,
+                /* expectAutoRequestsCount= */ 2,
+                /* expectAppUids= */ new int[] {TEST_UID},
+                expectRegisteredDurations,
+                expectActiveDurations,
+                new KeepaliveCarrierStats[] {
+                    getDefaultCarrierStats(expectRegisteredDurations[1], expectActiveDurations[1])
+                });
+    }
+
+    @Test
+    public void testCarrierIdChange_changeBeforeStart() {
+        // Update the list to only have sub_id_2 with carrier_id_1.
+        final SubscriptionInfo subInfo = makeSubInfoMock(TEST_SUB_ID_2, TEST_CARRIER_ID_1);
+        doReturn(List.of(subInfo)).when(mSubscriptionManager).getActiveSubscriptionInfoList();
+
+        getOnSubscriptionsChangedListener().onSubscriptionsChanged();
+        HandlerUtils.waitForIdle(mTestHandler, TIMEOUT_MS);
+
+        final int startTime = 1000;
+        final int writeTime = 5000;
+
+        onStartKeepalive(startTime, TEST_SLOT, TEST_NETWORK_CAPABILITIES);
+        onStartKeepalive(startTime, TEST_SLOT2, TEST_NETWORK_CAPABILITIES_2);
+
+        final DailykeepaliveInfoReported dailyKeepaliveInfoReported =
+                buildKeepaliveMetrics(writeTime);
+
+        // The network with sub_id_1 has an unknown carrier id.
+        final KeepaliveCarrierStats expectKeepaliveCarrierStats1 =
+                new KeepaliveCarrierStats(
+                        TelephonyManager.UNKNOWN_CARRIER_ID,
+                        /* transportTypes= */ (1 << TRANSPORT_CELLULAR),
+                        TEST_KEEPALIVE_INTERVAL_SEC * 1000,
+                        writeTime - startTime,
+                        writeTime - startTime);
+
+        // The network with sub_id_2 has carrier_id_1.
+        final KeepaliveCarrierStats expectKeepaliveCarrierStats2 =
+                new KeepaliveCarrierStats(
+                        TEST_CARRIER_ID_1,
+                        /* transportTypes= */ (1 << TRANSPORT_CELLULAR),
+                        TEST_KEEPALIVE_INTERVAL_SEC * 1000,
+                        writeTime - startTime,
+                        writeTime - startTime);
+        assertDailyKeepaliveInfoReported(
+                dailyKeepaliveInfoReported,
+                /* expectRequestsCount= */ 2,
+                /* expectAutoRequestsCount= */ 2,
+                /* expectAppUids= */ new int[] {TEST_UID},
+                /* expectRegisteredDurations= */ new int[] {startTime, 0, writeTime - startTime},
+                /* expectActiveDurations= */ new int[] {startTime, 0, writeTime - startTime},
+                new KeepaliveCarrierStats[] {
+                    expectKeepaliveCarrierStats1, expectKeepaliveCarrierStats2
+                });
+    }
+
+    @Test
+    public void testCarrierIdFromWifiInfo() {
+        final int startTime = 1000;
+        final int writeTime = 5000;
+
+        final WifiInfo wifiInfo = mock(WifiInfo.class);
+        final WifiInfo wifiInfoCopy = mock(WifiInfo.class);
+
+        // Building NetworkCapabilities stores a copy of the WifiInfo with makeCopy.
+        doReturn(wifiInfoCopy).when(wifiInfo).makeCopy(anyLong());
+        doReturn(TEST_SUB_ID_1).when(wifiInfo).getSubscriptionId();
+        doReturn(TEST_SUB_ID_1).when(wifiInfoCopy).getSubscriptionId();
+        final NetworkCapabilities nc =
+                new NetworkCapabilities.Builder()
+                        .addTransportType(TRANSPORT_WIFI)
+                        .setTransportInfo(wifiInfo)
+                        .build();
+
+        onStartKeepalive(startTime, TEST_SLOT, nc);
+
+        final DailykeepaliveInfoReported dailyKeepaliveInfoReported =
+                buildKeepaliveMetrics(writeTime);
+
+        final KeepaliveCarrierStats expectKeepaliveCarrierStats =
+                new KeepaliveCarrierStats(
+                        TEST_CARRIER_ID_1,
+                        /* transportTypes= */ (1 << TRANSPORT_WIFI),
+                        TEST_KEEPALIVE_INTERVAL_SEC * 1000,
+                        writeTime - startTime,
+                        writeTime - startTime);
+
+        assertDailyKeepaliveInfoReported(
+                dailyKeepaliveInfoReported,
+                /* expectRequestsCount= */ 1,
+                /* expectAutoRequestsCount= */ 1,
+                /* expectAppUids= */ new int[] {TEST_UID},
+                /* expectRegisteredDurations= */ new int[] {startTime, writeTime - startTime},
+                /* expectActiveDurations= */ new int[] {startTime, writeTime - startTime},
+                new KeepaliveCarrierStats[] {expectKeepaliveCarrierStats});
+    }
+
+    @Test
+    public void testKeepaliveCountsAndUids() {
+        final int startTime1 = 1000, startTime2 = 2000, startTime3 = 3000;
+        final int writeTime = 5000;
+        final int[] uids = new int[] {TEST_UID, TEST_UID + 1, TEST_UID + 2};
+        onStartKeepalive(startTime1, TEST_SLOT, TEST_NETWORK_CAPABILITIES,
+                TEST_KEEPALIVE_INTERVAL_SEC, uids[0], /* isAutoKeepalive= */ true);
+        onStartKeepalive(startTime2, TEST_SLOT + 1, TEST_NETWORK_CAPABILITIES,
+                TEST_KEEPALIVE_INTERVAL_SEC, uids[1], /* isAutoKeepalive= */ false);
+        onStartKeepalive(startTime3, TEST_SLOT + 2, TEST_NETWORK_CAPABILITIES,
+                TEST_KEEPALIVE_INTERVAL_SEC, uids[2], /* isAutoKeepalive= */ true);
+
+        final DailykeepaliveInfoReported dailyKeepaliveInfoReported =
+                buildKeepaliveMetrics(writeTime);
+        final int[] expectRegisteredDurations =
+                new int[] {
+                    startTime1,
+                    (startTime2 - startTime1),
+                    (startTime3 - startTime2),
+                    (writeTime - startTime3)
+                };
+        final int[] expectActiveDurations =
+                new int[] {
+                    startTime1,
+                    (startTime2 - startTime1),
+                    (startTime3 - startTime2),
+                    (writeTime - startTime3)
+                };
+        assertDailyKeepaliveInfoReported(
+                dailyKeepaliveInfoReported,
+                /* expectRequestsCount= */ 3,
+                /* expectAutoRequestsCount= */ 2,
+                /* expectAppUids= */ uids,
+                expectRegisteredDurations,
+                expectActiveDurations,
+                new KeepaliveCarrierStats[] {
+                    getDefaultCarrierStats(
+                            writeTime * 3 - startTime1 - startTime2 - startTime3,
+                            writeTime * 3 - startTime1 - startTime2 - startTime3)
+                });
+    }
+
+    @Test
+    public void testUpdateDefaultSubId() {
+        final int startTime1 = 1000;
+        final int startTime2 = 3000;
+        final int writeTime = 5000;
+
+        // No TelephonyNetworkSpecifier set with subId to force the use of default subId.
+        final NetworkCapabilities nc =
+                new NetworkCapabilities.Builder().addTransportType(TRANSPORT_CELLULAR).build();
+        onStartKeepalive(startTime1, TEST_SLOT, nc);
+        // Update default subId
+        triggerBroadcastDefaultSubId(TEST_SUB_ID_1);
+        onStartKeepalive(startTime2, TEST_SLOT2, nc);
+
+        final DailykeepaliveInfoReported dailyKeepaliveInfoReported =
+                buildKeepaliveMetrics(writeTime);
+
+        final int[] expectRegisteredDurations =
+                new int[] {startTime1, startTime2 - startTime1, writeTime - startTime2};
+        final int[] expectActiveDurations =
+                new int[] {startTime1, startTime2 - startTime1, writeTime - startTime2};
+        // Expect the carrier id of the first keepalive to be unknown
+        final KeepaliveCarrierStats expectKeepaliveCarrierStats1 =
+                new KeepaliveCarrierStats(
+                        TelephonyManager.UNKNOWN_CARRIER_ID,
+                        /* transportTypes= */ (1 << TRANSPORT_CELLULAR),
+                        TEST_KEEPALIVE_INTERVAL_SEC * 1000,
+                        writeTime - startTime1,
+                        writeTime - startTime1);
+        // Expect the carrier id of the second keepalive to be TEST_CARRIER_ID_1, from TEST_SUB_ID_1
+        final KeepaliveCarrierStats expectKeepaliveCarrierStats2 =
+                new KeepaliveCarrierStats(
+                        TEST_CARRIER_ID_1,
+                        /* transportTypes= */ (1 << TRANSPORT_CELLULAR),
+                        TEST_KEEPALIVE_INTERVAL_SEC * 1000,
+                        writeTime - startTime2,
+                        writeTime - startTime2);
+        assertDailyKeepaliveInfoReported(
+                dailyKeepaliveInfoReported,
+                /* expectRequestsCount= */ 2,
+                /* expectAutoRequestsCount= */ 2,
+                /* expectAppUids= */ new int[] {TEST_UID},
+                expectRegisteredDurations,
+                expectActiveDurations,
+                new KeepaliveCarrierStats[] {
+                    expectKeepaliveCarrierStats1, expectKeepaliveCarrierStats2
+                });
     }
 }
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsAdvertiserTest.kt b/tests/unit/java/com/android/server/connectivity/mdns/MdnsAdvertiserTest.kt
index d9acc61..c467f45 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsAdvertiserTest.kt
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsAdvertiserTest.kt
@@ -56,7 +56,8 @@
 private val TEST_ADDR = parseNumericAddress("2001:db8::123")
 private val TEST_LINKADDR = LinkAddress(TEST_ADDR, 64 /* prefixLength */)
 private val TEST_NETWORK_1 = mock(Network::class.java)
-private val TEST_NETWORK_2 = mock(Network::class.java)
+private val TEST_SOCKETKEY_1 = mock(SocketKey::class.java)
+private val TEST_SOCKETKEY_2 = mock(SocketKey::class.java)
 private val TEST_HOSTNAME = arrayOf("Android_test", "local")
 private const val TEST_SUBTYPE = "_subtype"
 
@@ -145,7 +146,7 @@
         verify(socketProvider).requestSocket(eq(TEST_NETWORK_1), socketCbCaptor.capture())
 
         val socketCb = socketCbCaptor.value
-        postSync { socketCb.onSocketCreated(TEST_NETWORK_1, mockSocket1, listOf(TEST_LINKADDR)) }
+        postSync { socketCb.onSocketCreated(TEST_SOCKETKEY_1, mockSocket1, listOf(TEST_LINKADDR)) }
 
         val intAdvCbCaptor = ArgumentCaptor.forClass(MdnsInterfaceAdvertiser.Callback::class.java)
         verify(mockDeps).makeAdvertiser(
@@ -163,7 +164,7 @@
                 mockInterfaceAdvertiser1, SERVICE_ID_1) }
         verify(cb).onRegisterServiceSucceeded(eq(SERVICE_ID_1), argThat { it.matches(SERVICE_1) })
 
-        postSync { socketCb.onInterfaceDestroyed(TEST_NETWORK_1, mockSocket1) }
+        postSync { socketCb.onInterfaceDestroyed(TEST_SOCKETKEY_1, mockSocket1) }
         verify(mockInterfaceAdvertiser1).destroyNow()
     }
 
@@ -177,8 +178,8 @@
                 socketCbCaptor.capture())
 
         val socketCb = socketCbCaptor.value
-        postSync { socketCb.onSocketCreated(TEST_NETWORK_1, mockSocket1, listOf(TEST_LINKADDR)) }
-        postSync { socketCb.onSocketCreated(TEST_NETWORK_2, mockSocket2, listOf(TEST_LINKADDR)) }
+        postSync { socketCb.onSocketCreated(TEST_SOCKETKEY_1, mockSocket1, listOf(TEST_LINKADDR)) }
+        postSync { socketCb.onSocketCreated(TEST_SOCKETKEY_2, mockSocket2, listOf(TEST_LINKADDR)) }
 
         val intAdvCbCaptor1 = ArgumentCaptor.forClass(MdnsInterfaceAdvertiser.Callback::class.java)
         val intAdvCbCaptor2 = ArgumentCaptor.forClass(MdnsInterfaceAdvertiser.Callback::class.java)
@@ -241,8 +242,8 @@
 
         // Callbacks for matching network and all networks both get the socket
         postSync {
-            oneNetSocketCb.onSocketCreated(TEST_NETWORK_1, mockSocket1, listOf(TEST_LINKADDR))
-            allNetSocketCb.onSocketCreated(TEST_NETWORK_1, mockSocket1, listOf(TEST_LINKADDR))
+            oneNetSocketCb.onSocketCreated(TEST_SOCKETKEY_1, mockSocket1, listOf(TEST_LINKADDR))
+            allNetSocketCb.onSocketCreated(TEST_SOCKETKEY_1, mockSocket1, listOf(TEST_LINKADDR))
         }
 
         val expectedRenamed = NsdServiceInfo(
@@ -294,8 +295,8 @@
         verify(cb).onRegisterServiceSucceeded(eq(SERVICE_ID_2),
                 argThat { it.matches(expectedRenamed) })
 
-        postSync { oneNetSocketCb.onInterfaceDestroyed(TEST_NETWORK_1, mockSocket1) }
-        postSync { allNetSocketCb.onInterfaceDestroyed(TEST_NETWORK_1, mockSocket1) }
+        postSync { oneNetSocketCb.onInterfaceDestroyed(TEST_SOCKETKEY_1, mockSocket1) }
+        postSync { allNetSocketCb.onInterfaceDestroyed(TEST_SOCKETKEY_1, mockSocket1) }
 
         // destroyNow can be called multiple times
         verify(mockInterfaceAdvertiser1, atLeastOnce()).destroyNow()
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsDiscoveryManagerTests.java b/tests/unit/java/com/android/server/connectivity/mdns/MdnsDiscoveryManagerTests.java
index a24664e..d2298fe 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsDiscoveryManagerTests.java
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsDiscoveryManagerTests.java
@@ -28,7 +28,6 @@
 import static org.mockito.Mockito.when;
 
 import android.annotation.NonNull;
-import android.annotation.Nullable;
 import android.net.Network;
 import android.os.Handler;
 import android.os.HandlerThread;
@@ -65,17 +64,22 @@
     private static final String SERVICE_TYPE_2 = "_test._tcp.local";
     private static final Network NETWORK_1 = Mockito.mock(Network.class);
     private static final Network NETWORK_2 = Mockito.mock(Network.class);
-    private static final Pair<String, Network> PER_NETWORK_SERVICE_TYPE_1_NULL_NETWORK =
-            Pair.create(SERVICE_TYPE_1, null);
-    private static final Pair<String, Network> PER_NETWORK_SERVICE_TYPE_1_NETWORK_1 =
-            Pair.create(SERVICE_TYPE_1, NETWORK_1);
-    private static final Pair<String, Network> PER_NETWORK_SERVICE_TYPE_2_NULL_NETWORK =
-            Pair.create(SERVICE_TYPE_2, null);
-    private static final Pair<String, Network> PER_NETWORK_SERVICE_TYPE_2_NETWORK_1 =
-            Pair.create(SERVICE_TYPE_2, NETWORK_1);
-    private static final Pair<String, Network> PER_NETWORK_SERVICE_TYPE_2_NETWORK_2 =
-            Pair.create(SERVICE_TYPE_2, NETWORK_2);
-
+    private static final SocketKey SOCKET_KEY_NULL_NETWORK =
+            new SocketKey(null /* network */, 999 /* interfaceIndex */);
+    private static final SocketKey SOCKET_KEY_NETWORK_1 =
+            new SocketKey(NETWORK_1, 998 /* interfaceIndex */);
+    private static final SocketKey SOCKET_KEY_NETWORK_2 =
+            new SocketKey(NETWORK_2, 997 /* interfaceIndex */);
+    private static final Pair<String, SocketKey> PER_SOCKET_SERVICE_TYPE_1_NULL_NETWORK =
+            Pair.create(SERVICE_TYPE_1, SOCKET_KEY_NULL_NETWORK);
+    private static final Pair<String, SocketKey> PER_SOCKET_SERVICE_TYPE_2_NULL_NETWORK =
+            Pair.create(SERVICE_TYPE_2, SOCKET_KEY_NULL_NETWORK);
+    private static final Pair<String, SocketKey> PER_SOCKET_SERVICE_TYPE_1_NETWORK_1 =
+            Pair.create(SERVICE_TYPE_1, SOCKET_KEY_NETWORK_1);
+    private static final Pair<String, SocketKey> PER_SOCKET_SERVICE_TYPE_2_NETWORK_1 =
+            Pair.create(SERVICE_TYPE_2, SOCKET_KEY_NETWORK_1);
+    private static final Pair<String, SocketKey> PER_SOCKET_SERVICE_TYPE_2_NETWORK_2 =
+            Pair.create(SERVICE_TYPE_2, SOCKET_KEY_NETWORK_2);
     @Mock private ExecutorProvider executorProvider;
     @Mock private MdnsSocketClientBase socketClient;
     @Mock private MdnsServiceTypeClient mockServiceTypeClientType1NullNetwork;
@@ -104,22 +108,22 @@
                 sharedLog) {
                     @Override
                     MdnsServiceTypeClient createServiceTypeClient(@NonNull String serviceType,
-                            @Nullable Network network) {
-                        final Pair<String, Network> perNetworkServiceType =
-                                Pair.create(serviceType, network);
-                        if (perNetworkServiceType.equals(PER_NETWORK_SERVICE_TYPE_1_NULL_NETWORK)) {
+                            @NonNull SocketKey socketKey) {
+                        final Pair<String, SocketKey> perSocketServiceType =
+                                Pair.create(serviceType, socketKey);
+                        if (perSocketServiceType.equals(PER_SOCKET_SERVICE_TYPE_1_NULL_NETWORK)) {
                             return mockServiceTypeClientType1NullNetwork;
-                        } else if (perNetworkServiceType.equals(
-                                PER_NETWORK_SERVICE_TYPE_1_NETWORK_1)) {
+                        } else if (perSocketServiceType.equals(
+                                PER_SOCKET_SERVICE_TYPE_1_NETWORK_1)) {
                             return mockServiceTypeClientType1Network1;
-                        } else if (perNetworkServiceType.equals(
-                                PER_NETWORK_SERVICE_TYPE_2_NULL_NETWORK)) {
+                        } else if (perSocketServiceType.equals(
+                                PER_SOCKET_SERVICE_TYPE_2_NULL_NETWORK)) {
                             return mockServiceTypeClientType2NullNetwork;
-                        } else if (perNetworkServiceType.equals(
-                                PER_NETWORK_SERVICE_TYPE_2_NETWORK_1)) {
+                        } else if (perSocketServiceType.equals(
+                                PER_SOCKET_SERVICE_TYPE_2_NETWORK_1)) {
                             return mockServiceTypeClientType2Network1;
-                        } else if (perNetworkServiceType.equals(
-                                PER_NETWORK_SERVICE_TYPE_2_NETWORK_2)) {
+                        } else if (perSocketServiceType.equals(
+                                PER_SOCKET_SERVICE_TYPE_2_NETWORK_2)) {
                             return mockServiceTypeClientType2Network2;
                         }
                         return null;
@@ -156,7 +160,7 @@
                 MdnsSearchOptions.newBuilder().setNetwork(null /* network */).build();
         final SocketCreationCallback callback = expectSocketCreationCallback(
                 SERVICE_TYPE_1, mockListenerOne, options);
-        runOnHandler(() -> callback.onSocketCreated(null /* network */));
+        runOnHandler(() -> callback.onSocketCreated(SOCKET_KEY_NULL_NETWORK));
         verify(mockServiceTypeClientType1NullNetwork).startSendAndReceive(mockListenerOne, options);
 
         when(mockServiceTypeClientType1NullNetwork.stopSendAndReceive(mockListenerOne))
@@ -172,16 +176,16 @@
                 MdnsSearchOptions.newBuilder().setNetwork(null /* network */).build();
         final SocketCreationCallback callback = expectSocketCreationCallback(
                 SERVICE_TYPE_1, mockListenerOne, options);
-        runOnHandler(() -> callback.onSocketCreated(null /* network */));
+        runOnHandler(() -> callback.onSocketCreated(SOCKET_KEY_NULL_NETWORK));
         verify(mockServiceTypeClientType1NullNetwork).startSendAndReceive(mockListenerOne, options);
-        runOnHandler(() -> callback.onSocketCreated(NETWORK_1));
+        runOnHandler(() -> callback.onSocketCreated(SOCKET_KEY_NETWORK_1));
         verify(mockServiceTypeClientType1Network1).startSendAndReceive(mockListenerOne, options);
 
         final SocketCreationCallback callback2 = expectSocketCreationCallback(
                 SERVICE_TYPE_2, mockListenerTwo, options);
-        runOnHandler(() -> callback2.onSocketCreated(null /* network */));
+        runOnHandler(() -> callback2.onSocketCreated(SOCKET_KEY_NULL_NETWORK));
         verify(mockServiceTypeClientType2NullNetwork).startSendAndReceive(mockListenerTwo, options);
-        runOnHandler(() -> callback2.onSocketCreated(NETWORK_2));
+        runOnHandler(() -> callback2.onSocketCreated(SOCKET_KEY_NETWORK_2));
         verify(mockServiceTypeClientType2Network2).startSendAndReceive(mockListenerTwo, options);
     }
 
@@ -191,49 +195,48 @@
                 MdnsSearchOptions.newBuilder().setNetwork(null /* network */).build();
         final SocketCreationCallback callback = expectSocketCreationCallback(
                 SERVICE_TYPE_1, mockListenerOne, options1);
-        runOnHandler(() -> callback.onSocketCreated(null /* network */));
+        runOnHandler(() -> callback.onSocketCreated(SOCKET_KEY_NULL_NETWORK));
         verify(mockServiceTypeClientType1NullNetwork).startSendAndReceive(
                 mockListenerOne, options1);
-        runOnHandler(() -> callback.onSocketCreated(NETWORK_1));
+        runOnHandler(() -> callback.onSocketCreated(SOCKET_KEY_NETWORK_1));
         verify(mockServiceTypeClientType1Network1).startSendAndReceive(mockListenerOne, options1);
 
         final MdnsSearchOptions options2 =
                 MdnsSearchOptions.newBuilder().setNetwork(NETWORK_2).build();
         final SocketCreationCallback callback2 = expectSocketCreationCallback(
                 SERVICE_TYPE_2, mockListenerTwo, options2);
-        runOnHandler(() -> callback2.onSocketCreated(NETWORK_2));
+        runOnHandler(() -> callback2.onSocketCreated(SOCKET_KEY_NETWORK_2));
         verify(mockServiceTypeClientType2Network2).startSendAndReceive(mockListenerTwo, options2);
 
         final MdnsPacket responseForServiceTypeOne = createMdnsPacket(SERVICE_TYPE_1);
-        final int ifIndex = 1;
         runOnHandler(() -> discoveryManager.onResponseReceived(
-                responseForServiceTypeOne, ifIndex, null /* network */));
+                responseForServiceTypeOne, SOCKET_KEY_NULL_NETWORK));
         // Packets for network null are only processed by the ServiceTypeClient for network null
         verify(mockServiceTypeClientType1NullNetwork).processResponse(responseForServiceTypeOne,
-                ifIndex, null /* network */);
+                SOCKET_KEY_NULL_NETWORK.getInterfaceIndex(), SOCKET_KEY_NULL_NETWORK.getNetwork());
         verify(mockServiceTypeClientType1Network1, never()).processResponse(any(), anyInt(), any());
         verify(mockServiceTypeClientType2Network2, never()).processResponse(any(), anyInt(), any());
 
         final MdnsPacket responseForServiceTypeTwo = createMdnsPacket(SERVICE_TYPE_2);
         runOnHandler(() -> discoveryManager.onResponseReceived(
-                responseForServiceTypeTwo, ifIndex, NETWORK_1));
+                responseForServiceTypeTwo, SOCKET_KEY_NETWORK_1));
         verify(mockServiceTypeClientType1NullNetwork, never()).processResponse(any(), anyInt(),
-                eq(NETWORK_1));
+                eq(SOCKET_KEY_NETWORK_1.getNetwork()));
         verify(mockServiceTypeClientType1Network1).processResponse(responseForServiceTypeTwo,
-                ifIndex, NETWORK_1);
+                SOCKET_KEY_NETWORK_1.getInterfaceIndex(), SOCKET_KEY_NETWORK_1.getNetwork());
         verify(mockServiceTypeClientType2Network2, never()).processResponse(any(), anyInt(),
-                eq(NETWORK_1));
+                eq(SOCKET_KEY_NETWORK_1.getNetwork()));
 
         final MdnsPacket responseForSubtype =
                 createMdnsPacket("subtype._sub._googlecast._tcp.local");
         runOnHandler(() -> discoveryManager.onResponseReceived(
-                responseForSubtype, ifIndex, NETWORK_2));
-        verify(mockServiceTypeClientType1NullNetwork, never()).processResponse(
-                any(), anyInt(), eq(NETWORK_2));
-        verify(mockServiceTypeClientType1Network1, never()).processResponse(
-                any(), anyInt(), eq(NETWORK_2));
-        verify(mockServiceTypeClientType2Network2).processResponse(
-                responseForSubtype, ifIndex, NETWORK_2);
+                responseForSubtype, SOCKET_KEY_NETWORK_2));
+        verify(mockServiceTypeClientType1NullNetwork, never()).processResponse(any(), anyInt(),
+                eq(SOCKET_KEY_NETWORK_2.getNetwork()));
+        verify(mockServiceTypeClientType1Network1, never()).processResponse(any(), anyInt(),
+                eq(SOCKET_KEY_NETWORK_2.getNetwork()));
+        verify(mockServiceTypeClientType2Network2).processResponse(responseForSubtype,
+                SOCKET_KEY_NETWORK_2.getInterfaceIndex(), SOCKET_KEY_NETWORK_2.getNetwork());
     }
 
     @Test
@@ -243,55 +246,53 @@
                 MdnsSearchOptions.newBuilder().setNetwork(NETWORK_1).build();
         final SocketCreationCallback callback = expectSocketCreationCallback(
                 SERVICE_TYPE_1, mockListenerOne, network1Options);
-        runOnHandler(() -> callback.onSocketCreated(NETWORK_1));
+        runOnHandler(() -> callback.onSocketCreated(SOCKET_KEY_NETWORK_1));
         verify(mockServiceTypeClientType1Network1).startSendAndReceive(
                 mockListenerOne, network1Options);
 
         // Create a ServiceTypeClient for SERVICE_TYPE_2 and NETWORK_1
         final SocketCreationCallback callback2 = expectSocketCreationCallback(
                 SERVICE_TYPE_2, mockListenerTwo, network1Options);
-        runOnHandler(() -> callback2.onSocketCreated(NETWORK_1));
+        runOnHandler(() -> callback2.onSocketCreated(SOCKET_KEY_NETWORK_1));
         verify(mockServiceTypeClientType2Network1).startSendAndReceive(
                 mockListenerTwo, network1Options);
 
         // Receive a response, it should be processed on both clients.
         final MdnsPacket response = createMdnsPacket(SERVICE_TYPE_1);
-        final int ifIndex = 1;
-        runOnHandler(() -> discoveryManager.onResponseReceived(
-                response, ifIndex, NETWORK_1));
-        verify(mockServiceTypeClientType1Network1).processResponse(response, ifIndex, NETWORK_1);
-        verify(mockServiceTypeClientType2Network1).processResponse(response, ifIndex, NETWORK_1);
+        runOnHandler(() -> discoveryManager.onResponseReceived(response, SOCKET_KEY_NETWORK_1));
+        verify(mockServiceTypeClientType1Network1).processResponse(response,
+                SOCKET_KEY_NETWORK_1.getInterfaceIndex(), SOCKET_KEY_NETWORK_1.getNetwork());
+        verify(mockServiceTypeClientType2Network1).processResponse(response,
+                SOCKET_KEY_NETWORK_1.getInterfaceIndex(), SOCKET_KEY_NETWORK_1.getNetwork());
 
         // The first callback receives a notification that the network has been destroyed,
         // mockServiceTypeClientOne1 should send service removed notifications and remove from the
         // list of clients.
-        runOnHandler(() -> callback.onAllSocketsDestroyed(NETWORK_1));
+        runOnHandler(() -> callback.onAllSocketsDestroyed(SOCKET_KEY_NETWORK_1));
         verify(mockServiceTypeClientType1Network1).notifySocketDestroyed();
 
         // Receive a response again, it should be processed only on
         // mockServiceTypeClientType2Network1. Because the mockServiceTypeClientType1Network1 is
         // removed from the list of clients, it is no longer able to process responses.
-        runOnHandler(() -> discoveryManager.onResponseReceived(
-                response, ifIndex, NETWORK_1));
+        runOnHandler(() -> discoveryManager.onResponseReceived(response, SOCKET_KEY_NETWORK_1));
         // Still times(1) as a response was received once previously
-        verify(mockServiceTypeClientType1Network1, times(1))
-                .processResponse(response, ifIndex, NETWORK_1);
-        verify(mockServiceTypeClientType2Network1, times(2))
-                .processResponse(response, ifIndex, NETWORK_1);
+        verify(mockServiceTypeClientType1Network1, times(1)).processResponse(response,
+                SOCKET_KEY_NETWORK_1.getInterfaceIndex(), SOCKET_KEY_NETWORK_1.getNetwork());
+        verify(mockServiceTypeClientType2Network1, times(2)).processResponse(response,
+                SOCKET_KEY_NETWORK_1.getInterfaceIndex(), SOCKET_KEY_NETWORK_1.getNetwork());
 
         // The client for NETWORK_1 receives the callback that the NETWORK_2 has been destroyed,
         // mockServiceTypeClientTwo2 shouldn't send any notifications.
-        runOnHandler(() -> callback2.onAllSocketsDestroyed(NETWORK_2));
+        runOnHandler(() -> callback2.onAllSocketsDestroyed(SOCKET_KEY_NETWORK_2));
         verify(mockServiceTypeClientType2Network1, never()).notifySocketDestroyed();
 
         // Receive a response again, mockServiceTypeClientType2Network1 is still in the list of
         // clients, it's still able to process responses.
-        runOnHandler(() -> discoveryManager.onResponseReceived(
-                response, ifIndex, NETWORK_1));
-        verify(mockServiceTypeClientType1Network1, times(1))
-                .processResponse(response, ifIndex, NETWORK_1);
-        verify(mockServiceTypeClientType2Network1, times(3))
-                .processResponse(response, ifIndex, NETWORK_1);
+        runOnHandler(() -> discoveryManager.onResponseReceived(response, SOCKET_KEY_NETWORK_1));
+        verify(mockServiceTypeClientType1Network1, times(1)).processResponse(response,
+                SOCKET_KEY_NETWORK_1.getInterfaceIndex(), SOCKET_KEY_NETWORK_1.getNetwork());
+        verify(mockServiceTypeClientType2Network1, times(3)).processResponse(response,
+                SOCKET_KEY_NETWORK_1.getInterfaceIndex(), SOCKET_KEY_NETWORK_1.getNetwork());
     }
 
     @Test
@@ -301,27 +302,25 @@
                 MdnsSearchOptions.newBuilder().setNetwork(null /* network */).build();
         final SocketCreationCallback callback = expectSocketCreationCallback(
                 SERVICE_TYPE_1, mockListenerOne, network1Options);
-        runOnHandler(() -> callback.onSocketCreated(null /* network */));
+        runOnHandler(() -> callback.onSocketCreated(SOCKET_KEY_NULL_NETWORK));
         verify(mockServiceTypeClientType1NullNetwork).startSendAndReceive(
                 mockListenerOne, network1Options);
 
         // Receive a response, it should be processed on the client.
         final MdnsPacket response = createMdnsPacket(SERVICE_TYPE_1);
         final int ifIndex = 1;
-        runOnHandler(() -> discoveryManager.onResponseReceived(
-                response, ifIndex, null /* network */));
-        verify(mockServiceTypeClientType1NullNetwork).processResponse(
-                response, ifIndex, null /* network */);
+        runOnHandler(() -> discoveryManager.onResponseReceived(response, SOCKET_KEY_NULL_NETWORK));
+        verify(mockServiceTypeClientType1NullNetwork).processResponse(response,
+                SOCKET_KEY_NULL_NETWORK.getInterfaceIndex(), SOCKET_KEY_NULL_NETWORK.getNetwork());
 
-        runOnHandler(() -> callback.onAllSocketsDestroyed(null /* network */));
+        runOnHandler(() -> callback.onAllSocketsDestroyed(SOCKET_KEY_NULL_NETWORK));
         verify(mockServiceTypeClientType1NullNetwork).notifySocketDestroyed();
 
         // Receive a response again, it should not be processed.
-        runOnHandler(() -> discoveryManager.onResponseReceived(
-                response, ifIndex, null /* network */));
+        runOnHandler(() -> discoveryManager.onResponseReceived(response, SOCKET_KEY_NULL_NETWORK));
         // Still times(1) as a response was received once previously
-        verify(mockServiceTypeClientType1NullNetwork, times(1))
-                .processResponse(response, ifIndex, null /* network */);
+        verify(mockServiceTypeClientType1NullNetwork, times(1)).processResponse(response,
+                SOCKET_KEY_NULL_NETWORK.getInterfaceIndex(), SOCKET_KEY_NULL_NETWORK.getNetwork());
 
         // Unregister the listener, notifyNetworkUnrequested should be called but other stop methods
         // won't be call because the service type client was unregistered and destroyed. But those
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsMultinetworkSocketClientTest.java b/tests/unit/java/com/android/server/connectivity/mdns/MdnsMultinetworkSocketClientTest.java
index 87ba5d7..f7ef077 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsMultinetworkSocketClientTest.java
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsMultinetworkSocketClientTest.java
@@ -21,7 +21,6 @@
 
 import static org.junit.Assert.assertEquals;
 import static org.mockito.ArgumentMatchers.any;
-import static org.mockito.ArgumentMatchers.anyInt;
 import static org.mockito.Mockito.doReturn;
 import static org.mockito.Mockito.eq;
 import static org.mockito.Mockito.mock;
@@ -68,12 +67,15 @@
     @Mock private MdnsServiceBrowserListener mListener;
     @Mock private MdnsSocketClientBase.Callback mCallback;
     @Mock private SocketCreationCallback mSocketCreationCallback;
+    @Mock private SocketKey mSocketKey;
     private MdnsMultinetworkSocketClient mSocketClient;
     private Handler mHandler;
 
     @Before
     public void setUp() throws SocketException {
         MockitoAnnotations.initMocks(this);
+        doReturn(mNetwork).when(mSocketKey).getNetwork();
+
         final HandlerThread thread = new HandlerThread("MdnsMultinetworkSocketClientTest");
         thread.start();
         mHandler = new Handler(thread.getLooper());
@@ -123,13 +125,17 @@
             doReturn(createEmptyNetworkInterface()).when(socket).getInterface();
         }
 
+        final SocketKey tetherSocketKey1 = mock(SocketKey.class);
+        final SocketKey tetherSocketKey2 = mock(SocketKey.class);
+        doReturn(null).when(tetherSocketKey1).getNetwork();
+        doReturn(null).when(tetherSocketKey2).getNetwork();
         // Notify socket created
-        callback.onSocketCreated(mNetwork, mSocket, List.of());
-        verify(mSocketCreationCallback).onSocketCreated(mNetwork);
-        callback.onSocketCreated(null, tetherIfaceSock1, List.of());
-        verify(mSocketCreationCallback).onSocketCreated(null);
-        callback.onSocketCreated(null, tetherIfaceSock2, List.of());
-        verify(mSocketCreationCallback, times(2)).onSocketCreated(null);
+        callback.onSocketCreated(mSocketKey, mSocket, List.of());
+        verify(mSocketCreationCallback).onSocketCreated(mSocketKey);
+        callback.onSocketCreated(tetherSocketKey1, tetherIfaceSock1, List.of());
+        verify(mSocketCreationCallback).onSocketCreated(tetherSocketKey1);
+        callback.onSocketCreated(tetherSocketKey2, tetherIfaceSock2, List.of());
+        verify(mSocketCreationCallback).onSocketCreated(tetherSocketKey2);
 
         // Send packet to IPv4 with target network and verify sending has been called.
         mSocketClient.sendMulticastPacket(ipv4Packet, mNetwork);
@@ -164,8 +170,8 @@
 
         doReturn(createEmptyNetworkInterface()).when(mSocket).getInterface();
         // Notify socket created
-        callback.onSocketCreated(mNetwork, mSocket, List.of());
-        verify(mSocketCreationCallback).onSocketCreated(mNetwork);
+        callback.onSocketCreated(mSocketKey, mSocket, List.of());
+        verify(mSocketCreationCallback).onSocketCreated(mSocketKey);
 
         final ArgumentCaptor<PacketHandler> handlerCaptor =
                 ArgumentCaptor.forClass(PacketHandler.class);
@@ -176,7 +182,7 @@
         handler.handlePacket(data, data.length, null /* src */);
         final ArgumentCaptor<MdnsPacket> responseCaptor =
                 ArgumentCaptor.forClass(MdnsPacket.class);
-        verify(mCallback).onResponseReceived(responseCaptor.capture(), anyInt(), any());
+        verify(mCallback).onResponseReceived(responseCaptor.capture(), any());
         final MdnsPacket response = responseCaptor.getValue();
         assertEquals(0, response.questions.size());
         assertEquals(0, response.additionalRecords.size());
@@ -214,11 +220,14 @@
         doReturn(createEmptyNetworkInterface()).when(socket2).getInterface();
         doReturn(createEmptyNetworkInterface()).when(socket3).getInterface();
 
-        callback.onSocketCreated(mNetwork, mSocket, List.of());
-        callback.onSocketCreated(null, socket2, List.of());
-        callback.onSocketCreated(null, socket3, List.of());
-        verify(mSocketCreationCallback).onSocketCreated(mNetwork);
-        verify(mSocketCreationCallback, times(2)).onSocketCreated(null);
+        final SocketKey socketKey2 = mock(SocketKey.class);
+        final SocketKey socketKey3 = mock(SocketKey.class);
+        callback.onSocketCreated(mSocketKey, mSocket, List.of());
+        callback.onSocketCreated(socketKey2, socket2, List.of());
+        callback.onSocketCreated(socketKey3, socket3, List.of());
+        verify(mSocketCreationCallback).onSocketCreated(mSocketKey);
+        verify(mSocketCreationCallback).onSocketCreated(socketKey2);
+        verify(mSocketCreationCallback).onSocketCreated(socketKey3);
 
         // Send IPv4 packet on the non-null Network and verify sending has been called.
         mSocketClient.sendMulticastPacket(ipv4Packet, mNetwork);
@@ -241,11 +250,12 @@
         final SocketCallback callback2 = callback2Captor.getAllValues().get(1);
 
         // Notify socket created for all networks.
-        callback2.onSocketCreated(mNetwork, mSocket, List.of());
-        callback2.onSocketCreated(null, socket2, List.of());
-        callback2.onSocketCreated(null, socket3, List.of());
-        verify(socketCreationCb2).onSocketCreated(mNetwork);
-        verify(socketCreationCb2, times(2)).onSocketCreated(null);
+        callback2.onSocketCreated(mSocketKey, mSocket, List.of());
+        callback2.onSocketCreated(socketKey2, socket2, List.of());
+        callback2.onSocketCreated(socketKey3, socket3, List.of());
+        verify(socketCreationCb2).onSocketCreated(mSocketKey);
+        verify(socketCreationCb2).onSocketCreated(socketKey2);
+        verify(socketCreationCb2).onSocketCreated(socketKey3);
 
         // Send IPv4 packet to null network and verify sending to the 2 tethered interface sockets.
         mSocketClient.sendMulticastPacket(ipv4Packet, null);
@@ -286,17 +296,17 @@
         doReturn(createEmptyNetworkInterface()).when(mSocket).getInterface();
         doReturn(createEmptyNetworkInterface()).when(otherSocket).getInterface();
 
-        callback.onSocketCreated(null /* network */, mSocket, List.of());
-        verify(mSocketCreationCallback).onSocketCreated(null);
-        callback.onSocketCreated(null /* network */, otherSocket, List.of());
-        verify(mSocketCreationCallback, times(2)).onSocketCreated(null);
+        callback.onSocketCreated(mSocketKey, mSocket, List.of());
+        verify(mSocketCreationCallback).onSocketCreated(mSocketKey);
+        callback.onSocketCreated(mSocketKey, otherSocket, List.of());
+        verify(mSocketCreationCallback, times(2)).onSocketCreated(mSocketKey);
 
-        verify(mSocketCreationCallback, never()).onAllSocketsDestroyed(null /* network */);
+        verify(mSocketCreationCallback, never()).onAllSocketsDestroyed(mSocketKey);
         mHandler.post(() -> mSocketClient.notifyNetworkUnrequested(mListener));
         HandlerUtils.waitForIdle(mHandler, DEFAULT_TIMEOUT);
 
         verify(mProvider).unrequestSocket(callback);
-        verify(mSocketCreationCallback).onAllSocketsDestroyed(null /* network */);
+        verify(mSocketCreationCallback).onAllSocketsDestroyed(mSocketKey);
     }
 
     @Test
@@ -306,15 +316,15 @@
         doReturn(createEmptyNetworkInterface()).when(mSocket).getInterface();
         doReturn(createEmptyNetworkInterface()).when(otherSocket).getInterface();
 
-        callback.onSocketCreated(null /* network */, mSocket, List.of());
-        verify(mSocketCreationCallback).onSocketCreated(null);
-        callback.onSocketCreated(null /* network */, otherSocket, List.of());
-        verify(mSocketCreationCallback, times(2)).onSocketCreated(null);
+        callback.onSocketCreated(mSocketKey, mSocket, List.of());
+        verify(mSocketCreationCallback).onSocketCreated(mSocketKey);
+        callback.onSocketCreated(mSocketKey, otherSocket, List.of());
+        verify(mSocketCreationCallback, times(2)).onSocketCreated(mSocketKey);
 
         // Notify socket destroyed
-        callback.onInterfaceDestroyed(null /* network */, mSocket);
+        callback.onInterfaceDestroyed(mSocketKey, mSocket);
         verifyNoMoreInteractions(mSocketCreationCallback);
-        callback.onInterfaceDestroyed(null /* network */, otherSocket);
-        verify(mSocketCreationCallback).onAllSocketsDestroyed(null /* network */);
+        callback.onInterfaceDestroyed(mSocketKey, otherSocket);
+        verify(mSocketCreationCallback).onAllSocketsDestroyed(mSocketKey);
     }
 }
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceTypeClientTests.java b/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceTypeClientTests.java
index d1adecf..635a1d4 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceTypeClientTests.java
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsServiceTypeClientTests.java
@@ -118,6 +118,7 @@
     private FakeExecutor currentThreadExecutor = new FakeExecutor();
 
     private MdnsServiceTypeClient client;
+    private SocketKey socketKey;
 
     @Before
     @SuppressWarnings("DoNotMock")
@@ -128,6 +129,7 @@
         expectedIPv4Packets = new DatagramPacket[16];
         expectedIPv6Packets = new DatagramPacket[16];
         expectedSendFutures = new ScheduledFuture<?>[16];
+        socketKey = new SocketKey(mockNetwork, INTERFACE_INDEX);
 
         for (int i = 0; i < expectedSendFutures.length; ++i) {
             expectedIPv4Packets[i] = new DatagramPacket(buf, 0 /* offset */, 5 /* length */,
@@ -174,7 +176,7 @@
 
         client =
                 new MdnsServiceTypeClient(SERVICE_TYPE, mockSocketClient, currentThreadExecutor,
-                        mockDecoderClock, mockNetwork, mockSharedLog) {
+                        mockDecoderClock, socketKey, mockSharedLog) {
                     @Override
                     MdnsPacketWriter createMdnsPacketWriter() {
                         return mockPacketWriter;
@@ -325,7 +327,7 @@
         MdnsSearchOptions searchOptions =
                 MdnsSearchOptions.newBuilder().addSubtype("12345").setIsPassiveMode(false).build();
         QueryTaskConfig config = new QueryTaskConfig(
-                searchOptions.getSubtypes(), searchOptions.isPassiveMode(), 1, mockNetwork);
+                searchOptions.getSubtypes(), searchOptions.isPassiveMode(), 1, socketKey);
 
         // This is the first query. We will ask for unicast response.
         assertTrue(config.expectUnicastResponse);
@@ -354,7 +356,7 @@
         MdnsSearchOptions searchOptions =
                 MdnsSearchOptions.newBuilder().addSubtype("12345").setIsPassiveMode(false).build();
         QueryTaskConfig config = new QueryTaskConfig(
-                searchOptions.getSubtypes(), searchOptions.isPassiveMode(), 1, mockNetwork);
+                searchOptions.getSubtypes(), searchOptions.isPassiveMode(), 1, socketKey);
 
         // This is the first query. We will ask for unicast response.
         assertTrue(config.expectUnicastResponse);
@@ -508,9 +510,9 @@
 
         // Process a second response with a different port and updated text attributes.
         client.processResponse(createResponse(
-                "service-instance-1", ipV4Address, 5354,
-                /* subtype= */ "ABCDE",
-                Collections.singletonMap("key", "value"), TEST_TTL),
+                        "service-instance-1", ipV4Address, 5354,
+                        /* subtype= */ "ABCDE",
+                        Collections.singletonMap("key", "value"), TEST_TTL),
                 /* interfaceIndex= */ 20, mockNetwork);
 
         // Verify onServiceNameDiscovered was called once for the initial response.
@@ -563,9 +565,9 @@
 
         // Process a second response with a different port and updated text attributes.
         client.processResponse(createResponse(
-                "service-instance-1", ipV6Address, 5354,
-                /* subtype= */ "ABCDE",
-                Collections.singletonMap("key", "value"), TEST_TTL),
+                        "service-instance-1", ipV6Address, 5354,
+                        /* subtype= */ "ABCDE",
+                        Collections.singletonMap("key", "value"), TEST_TTL),
                 /* interfaceIndex= */ 20, mockNetwork);
 
         // Verify onServiceNameDiscovered was called once for the initial response.
@@ -709,7 +711,7 @@
         final String serviceInstanceName = "service-instance-1";
         client =
                 new MdnsServiceTypeClient(SERVICE_TYPE, mockSocketClient, currentThreadExecutor,
-                        mockDecoderClock, mockNetwork, mockSharedLog) {
+                        mockDecoderClock, socketKey, mockSharedLog) {
                     @Override
                     MdnsPacketWriter createMdnsPacketWriter() {
                         return mockPacketWriter;
@@ -750,7 +752,7 @@
         final String serviceInstanceName = "service-instance-1";
         client =
                 new MdnsServiceTypeClient(SERVICE_TYPE, mockSocketClient, currentThreadExecutor,
-                        mockDecoderClock, mockNetwork, mockSharedLog) {
+                        mockDecoderClock, socketKey, mockSharedLog) {
                     @Override
                     MdnsPacketWriter createMdnsPacketWriter() {
                         return mockPacketWriter;
@@ -783,7 +785,7 @@
         final String serviceInstanceName = "service-instance-1";
         client =
                 new MdnsServiceTypeClient(SERVICE_TYPE, mockSocketClient, currentThreadExecutor,
-                        mockDecoderClock, mockNetwork, mockSharedLog) {
+                        mockDecoderClock, socketKey, mockSharedLog) {
                     @Override
                     MdnsPacketWriter createMdnsPacketWriter() {
                         return mockPacketWriter;
@@ -835,8 +837,8 @@
 
         // Process the last response which is goodbye message (with the main type, not subtype).
         client.processResponse(createResponse(
-                serviceName, ipV6Address, 5354, SERVICE_TYPE_LABELS,
-                Collections.singletonMap("key", "value"), /* ptrTtlMillis= */ 0L),
+                        serviceName, ipV6Address, 5354, SERVICE_TYPE_LABELS,
+                        Collections.singletonMap("key", "value"), /* ptrTtlMillis= */ 0L),
                 INTERFACE_INDEX, mockNetwork);
 
         // Verify onServiceNameDiscovered was first called for the initial response.
@@ -908,7 +910,7 @@
     @Test
     public void testProcessResponse_Resolve() throws Exception {
         client = new MdnsServiceTypeClient(
-                SERVICE_TYPE, mockSocketClient, currentThreadExecutor, mockNetwork, mockSharedLog);
+                SERVICE_TYPE, mockSocketClient, currentThreadExecutor, socketKey, mockSharedLog);
 
         final String instanceName = "service-instance";
         final String[] hostname = new String[] { "testhost "};
@@ -998,7 +1000,7 @@
     @Test
     public void testRenewTxtSrvInResolve() throws Exception {
         client = new MdnsServiceTypeClient(SERVICE_TYPE, mockSocketClient, currentThreadExecutor,
-                mockDecoderClock, mockNetwork, mockSharedLog);
+                mockDecoderClock, socketKey, mockSharedLog);
 
         final String instanceName = "service-instance";
         final String[] hostname = new String[] { "testhost "};
@@ -1102,7 +1104,7 @@
     @Test
     public void testProcessResponse_ResolveExcludesOtherServices() {
         client = new MdnsServiceTypeClient(
-                SERVICE_TYPE, mockSocketClient, currentThreadExecutor, mockNetwork, mockSharedLog);
+                SERVICE_TYPE, mockSocketClient, currentThreadExecutor, socketKey, mockSharedLog);
 
         final String requestedInstance = "instance1";
         final String otherInstance = "instance2";
@@ -1119,13 +1121,13 @@
 
         // Complete response from instanceName
         client.processResponse(createResponse(
-                requestedInstance, ipV4Address, 5353, SERVICE_TYPE_LABELS,
+                        requestedInstance, ipV4Address, 5353, SERVICE_TYPE_LABELS,
                         Collections.emptyMap() /* textAttributes */, TEST_TTL),
                 INTERFACE_INDEX, mockNetwork);
 
         // Complete response from otherInstanceName
         client.processResponse(createResponse(
-                otherInstance, ipV4Address, 5353, SERVICE_TYPE_LABELS,
+                        otherInstance, ipV4Address, 5353, SERVICE_TYPE_LABELS,
                         Collections.emptyMap() /* textAttributes */, TEST_TTL),
                 INTERFACE_INDEX, mockNetwork);
 
@@ -1166,7 +1168,7 @@
     @Test
     public void testProcessResponse_SubtypeDiscoveryLimitedToSubtype() {
         client = new MdnsServiceTypeClient(
-                SERVICE_TYPE, mockSocketClient, currentThreadExecutor, mockNetwork, mockSharedLog);
+                SERVICE_TYPE, mockSocketClient, currentThreadExecutor, socketKey, mockSharedLog);
 
         final String matchingInstance = "instance1";
         final String subtype = "_subtype";
@@ -1247,7 +1249,7 @@
     @Test
     public void testNotifySocketDestroyed() throws Exception {
         client = new MdnsServiceTypeClient(
-                SERVICE_TYPE, mockSocketClient, currentThreadExecutor, mockNetwork, mockSharedLog);
+                SERVICE_TYPE, mockSocketClient, currentThreadExecutor, socketKey, mockSharedLog);
 
         final String requestedInstance = "instance1";
         final String otherInstance = "instance2";
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsSocketClientTests.java b/tests/unit/java/com/android/server/connectivity/mdns/MdnsSocketClientTests.java
index abb1747..e30c249 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsSocketClientTests.java
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsSocketClientTests.java
@@ -23,6 +23,7 @@
 import static org.junit.Assert.assertTrue;
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.ArgumentMatchers.anyInt;
+import static org.mockito.ArgumentMatchers.argThat;
 import static org.mockito.ArgumentMatchers.eq;
 import static org.mockito.Mockito.doAnswer;
 import static org.mockito.Mockito.never;
@@ -370,7 +371,7 @@
         mdnsClient.startDiscovery();
 
         verify(mockCallback, timeout(TIMEOUT).atLeast(1))
-                .onResponseReceived(any(MdnsPacket.class), anyInt(), any());
+                .onResponseReceived(any(MdnsPacket.class), any(SocketKey.class));
     }
 
     @Test
@@ -379,7 +380,7 @@
         mdnsClient.startDiscovery();
 
         verify(mockCallback, timeout(TIMEOUT).atLeastOnce())
-                .onResponseReceived(any(MdnsPacket.class), anyInt(), any());
+                .onResponseReceived(any(MdnsPacket.class), any(SocketKey.class));
 
         mdnsClient.stopDiscovery();
     }
@@ -513,7 +514,7 @@
         mdnsClient.startDiscovery();
 
         verify(mockCallback, timeout(TIMEOUT).atLeastOnce())
-                .onResponseReceived(any(), eq(21), any());
+                .onResponseReceived(any(), argThat(key -> key.getInterfaceIndex() == 21));
     }
 
     @Test
@@ -536,6 +537,7 @@
         mdnsClient.startDiscovery();
 
         verify(mockMulticastSocket, never()).getInterfaceIndex();
-        verify(mockCallback, timeout(TIMEOUT).atLeast(1)).onResponseReceived(any(), eq(-1), any());
+        verify(mockCallback, timeout(TIMEOUT).atLeast(1))
+                .onResponseReceived(any(), argThat(key -> key.getInterfaceIndex() == -1));
     }
 }
\ No newline at end of file
diff --git a/tests/unit/java/com/android/server/connectivity/mdns/MdnsSocketProviderTest.java b/tests/unit/java/com/android/server/connectivity/mdns/MdnsSocketProviderTest.java
index 4ef64cb..0eac5ec 100644
--- a/tests/unit/java/com/android/server/connectivity/mdns/MdnsSocketProviderTest.java
+++ b/tests/unit/java/com/android/server/connectivity/mdns/MdnsSocketProviderTest.java
@@ -157,6 +157,7 @@
                 TETHERED_IFACE_NAME);
         doReturn(789).when(mDeps).getNetworkInterfaceIndexByName(
                 WIFI_P2P_IFACE_NAME);
+        doReturn(TETHERED_IFACE_IDX).when(mDeps).getInterfaceIndex(any());
         final HandlerThread thread = new HandlerThread("MdnsSocketProviderTest");
         thread.start();
         mHandler = new Handler(thread.getLooper());
@@ -227,30 +228,30 @@
 
     private class TestSocketCallback implements MdnsSocketProvider.SocketCallback {
         private class SocketEvent {
-            public final Network mNetwork;
+            public final SocketKey mSocketKey;
             public final List<LinkAddress> mAddresses;
 
-            SocketEvent(Network network, List<LinkAddress> addresses) {
-                mNetwork = network;
+            SocketEvent(SocketKey socketKey, List<LinkAddress> addresses) {
+                mSocketKey = socketKey;
                 mAddresses = Collections.unmodifiableList(addresses);
             }
         }
 
         private class SocketCreatedEvent extends SocketEvent {
-            SocketCreatedEvent(Network nw, List<LinkAddress> addresses) {
-                super(nw, addresses);
+            SocketCreatedEvent(SocketKey socketKey, List<LinkAddress> addresses) {
+                super(socketKey, addresses);
             }
         }
 
         private class InterfaceDestroyedEvent extends SocketEvent {
-            InterfaceDestroyedEvent(Network nw, List<LinkAddress> addresses) {
-                super(nw, addresses);
+            InterfaceDestroyedEvent(SocketKey socketKey, List<LinkAddress> addresses) {
+                super(socketKey, addresses);
             }
         }
 
         private class AddressesChangedEvent extends SocketEvent {
-            AddressesChangedEvent(Network nw, List<LinkAddress> addresses) {
-                super(nw, addresses);
+            AddressesChangedEvent(SocketKey socketKey, List<LinkAddress> addresses) {
+                super(socketKey, addresses);
             }
         }
 
@@ -258,27 +259,27 @@
                 new ArrayTrackRecord<SocketEvent>().newReadHead();
 
         @Override
-        public void onSocketCreated(Network network, MdnsInterfaceSocket socket,
+        public void onSocketCreated(SocketKey socketKey, MdnsInterfaceSocket socket,
                 List<LinkAddress> addresses) {
-            mHistory.add(new SocketCreatedEvent(network, addresses));
+            mHistory.add(new SocketCreatedEvent(socketKey, addresses));
         }
 
         @Override
-        public void onInterfaceDestroyed(Network network, MdnsInterfaceSocket socket) {
-            mHistory.add(new InterfaceDestroyedEvent(network, List.of()));
+        public void onInterfaceDestroyed(SocketKey socketKey, MdnsInterfaceSocket socket) {
+            mHistory.add(new InterfaceDestroyedEvent(socketKey, List.of()));
         }
 
         @Override
-        public void onAddressesChanged(Network network, MdnsInterfaceSocket socket,
+        public void onAddressesChanged(SocketKey socketKey, MdnsInterfaceSocket socket,
                 List<LinkAddress> addresses) {
-            mHistory.add(new AddressesChangedEvent(network, addresses));
+            mHistory.add(new AddressesChangedEvent(socketKey, addresses));
         }
 
         public void expectedSocketCreatedForNetwork(Network network, List<LinkAddress> addresses) {
             final SocketEvent event = mHistory.poll(0L /* timeoutMs */, c -> true);
             assertNotNull(event);
             assertTrue(event instanceof SocketCreatedEvent);
-            assertEquals(network, event.mNetwork);
+            assertEquals(network, event.mSocketKey.getNetwork());
             assertEquals(addresses, event.mAddresses);
         }
 
@@ -286,7 +287,7 @@
             final SocketEvent event = mHistory.poll(0L /* timeoutMs */, c -> true);
             assertNotNull(event);
             assertTrue(event instanceof InterfaceDestroyedEvent);
-            assertEquals(network, event.mNetwork);
+            assertEquals(network, event.mSocketKey.getNetwork());
         }
 
         public void expectedAddressesChangedForNetwork(Network network,
@@ -294,7 +295,7 @@
             final SocketEvent event = mHistory.poll(0L /* timeoutMs */, c -> true);
             assertNotNull(event);
             assertTrue(event instanceof AddressesChangedEvent);
-            assertEquals(network, event.mNetwork);
+            assertEquals(network, event.mSocketKey.getNetwork());
             assertEquals(event.mAddresses, addresses);
         }