Merge "Add more tests for setProfileNetworkPreferences"
diff --git a/bpf_progs/Android.bp b/bpf_progs/Android.bp
index 0e7b22d..6c78244 100644
--- a/bpf_progs/Android.bp
+++ b/bpf_progs/Android.bp
@@ -61,6 +61,7 @@
 bpf {
     name: "block.o",
     srcs: ["block.c"],
+    btf: true,
     cflags: [
         "-Wall",
         "-Werror",
@@ -71,6 +72,7 @@
 bpf {
     name: "dscp_policy.o",
     srcs: ["dscp_policy.c"],
+    btf: true,
     cflags: [
         "-Wall",
         "-Werror",
@@ -99,6 +101,7 @@
 bpf {
     name: "clatd.o",
     srcs: ["clatd.c"],
+    btf: true,
     cflags: [
         "-Wall",
         "-Werror",
@@ -112,6 +115,7 @@
 bpf {
     name: "netd.o",
     srcs: ["netd.c"],
+    btf: true,
     cflags: [
         "-Wall",
         "-Werror",
diff --git a/framework/src/android/net/QosCallbackException.java b/framework/src/android/net/QosCallbackException.java
index ed6eb15..b80cff4 100644
--- a/framework/src/android/net/QosCallbackException.java
+++ b/framework/src/android/net/QosCallbackException.java
@@ -46,8 +46,10 @@
             EX_TYPE_FILTER_NONE,
             EX_TYPE_FILTER_NETWORK_RELEASED,
             EX_TYPE_FILTER_SOCKET_NOT_BOUND,
+            EX_TYPE_FILTER_SOCKET_NOT_CONNECTED,
             EX_TYPE_FILTER_NOT_SUPPORTED,
             EX_TYPE_FILTER_SOCKET_LOCAL_ADDRESS_CHANGED,
+            EX_TYPE_FILTER_SOCKET_REMOTE_ADDRESS_CHANGED,
     })
     @Retention(RetentionPolicy.SOURCE)
     public @interface ExceptionType {}
@@ -65,10 +67,16 @@
     public static final int EX_TYPE_FILTER_SOCKET_NOT_BOUND = 2;
 
     /** {@hide} */
-    public static final int EX_TYPE_FILTER_NOT_SUPPORTED = 3;
+    public static final int EX_TYPE_FILTER_SOCKET_NOT_CONNECTED = 3;
 
     /** {@hide} */
-    public static final int EX_TYPE_FILTER_SOCKET_LOCAL_ADDRESS_CHANGED = 4;
+    public static final int EX_TYPE_FILTER_NOT_SUPPORTED = 4;
+
+    /** {@hide} */
+    public static final int EX_TYPE_FILTER_SOCKET_LOCAL_ADDRESS_CHANGED = 5;
+
+    /** {@hide} */
+    public static final int EX_TYPE_FILTER_SOCKET_REMOTE_ADDRESS_CHANGED = 6;
 
     /**
      * Creates exception based off of a type and message.  Not all types of exceptions accept a
@@ -83,12 +91,17 @@
                 return new QosCallbackException(new NetworkReleasedException());
             case EX_TYPE_FILTER_SOCKET_NOT_BOUND:
                 return new QosCallbackException(new SocketNotBoundException());
+            case EX_TYPE_FILTER_SOCKET_NOT_CONNECTED:
+                return new QosCallbackException(new SocketNotConnectedException());
             case EX_TYPE_FILTER_NOT_SUPPORTED:
                 return new QosCallbackException(new UnsupportedOperationException(
                         "This device does not support the specified filter"));
             case EX_TYPE_FILTER_SOCKET_LOCAL_ADDRESS_CHANGED:
                 return new QosCallbackException(
                         new SocketLocalAddressChangedException());
+            case EX_TYPE_FILTER_SOCKET_REMOTE_ADDRESS_CHANGED:
+                return new QosCallbackException(
+                        new SocketRemoteAddressChangedException());
             default:
                 Log.wtf(TAG, "create: No case setup for exception type: '" + type + "'");
                 return new QosCallbackException(
diff --git a/framework/src/android/net/QosFilter.java b/framework/src/android/net/QosFilter.java
index 5c1c3cc..b432644 100644
--- a/framework/src/android/net/QosFilter.java
+++ b/framework/src/android/net/QosFilter.java
@@ -90,5 +90,15 @@
      */
     public abstract boolean matchesRemoteAddress(@NonNull InetAddress address,
             int startPort, int endPort);
+
+    /**
+     * Determines whether or not the given protocol is matched by this filter.
+     *
+     * @param protocol the protocol such as TCP or UDP included in IP packet filter set of a QoS
+     *                 flow assigned on {@link Network}.
+     * @return whether the protocol matches the socket type of the filter
+     * @hide
+     */
+    public abstract boolean matchesProtocol(int protocol);
 }
 
diff --git a/framework/src/android/net/QosFilterParcelable.java b/framework/src/android/net/QosFilterParcelable.java
index da3b2cf..6e71fa3 100644
--- a/framework/src/android/net/QosFilterParcelable.java
+++ b/framework/src/android/net/QosFilterParcelable.java
@@ -104,7 +104,7 @@
         if (mQosFilter instanceof QosSocketFilter) {
             dest.writeInt(QOS_SOCKET_FILTER);
             final QosSocketFilter qosSocketFilter = (QosSocketFilter) mQosFilter;
-            qosSocketFilter.getQosSocketInfo().writeToParcel(dest, 0);
+            qosSocketFilter.getQosSocketInfo().writeToParcelWithoutFd(dest, 0);
             return;
         }
         dest.writeInt(NO_FILTER_PRESENT);
diff --git a/framework/src/android/net/QosSocketFilter.java b/framework/src/android/net/QosSocketFilter.java
index 69da7f4..5ceeb67 100644
--- a/framework/src/android/net/QosSocketFilter.java
+++ b/framework/src/android/net/QosSocketFilter.java
@@ -18,6 +18,13 @@
 
 import static android.net.QosCallbackException.EX_TYPE_FILTER_NONE;
 import static android.net.QosCallbackException.EX_TYPE_FILTER_SOCKET_LOCAL_ADDRESS_CHANGED;
+import static android.net.QosCallbackException.EX_TYPE_FILTER_SOCKET_NOT_BOUND;
+import static android.net.QosCallbackException.EX_TYPE_FILTER_SOCKET_NOT_CONNECTED;
+import static android.net.QosCallbackException.EX_TYPE_FILTER_SOCKET_REMOTE_ADDRESS_CHANGED;
+import static android.system.OsConstants.IPPROTO_TCP;
+import static android.system.OsConstants.IPPROTO_UDP;
+import static android.system.OsConstants.SOCK_DGRAM;
+import static android.system.OsConstants.SOCK_STREAM;
 
 import android.annotation.NonNull;
 import android.annotation.Nullable;
@@ -74,19 +81,34 @@
      * 2. In the scenario that the socket is now bound to a different local address, which can
      *    happen in the case of UDP, then
      *    {@link QosCallbackException.EX_TYPE_FILTER_SOCKET_LOCAL_ADDRESS_CHANGED} is returned.
+     * 3. In the scenario that the UDP socket changed remote address, then
+     *    {@link QosCallbackException.EX_TYPE_FILTER_SOCKET_REMOTE_ADDRESS_CHANGED} is returned.
+     *
      * @return validation error code
      */
     @Override
     public int validate() {
-        final InetSocketAddress sa = getAddressFromFileDescriptor();
-        if (sa == null) {
-            return QosCallbackException.EX_TYPE_FILTER_SOCKET_NOT_BOUND;
+        final InetSocketAddress sa = getLocalAddressFromFileDescriptor();
+
+        if (sa == null || (sa.getAddress().isAnyLocalAddress() && sa.getPort() == 0)) {
+            return EX_TYPE_FILTER_SOCKET_NOT_BOUND;
         }
 
         if (!sa.equals(mQosSocketInfo.getLocalSocketAddress())) {
             return EX_TYPE_FILTER_SOCKET_LOCAL_ADDRESS_CHANGED;
         }
 
+        if (mQosSocketInfo.getRemoteSocketAddress() != null) {
+            final InetSocketAddress da = getRemoteAddressFromFileDescriptor();
+            if (da == null) {
+                return EX_TYPE_FILTER_SOCKET_NOT_CONNECTED;
+            }
+
+            if (!da.equals(mQosSocketInfo.getRemoteSocketAddress())) {
+                return EX_TYPE_FILTER_SOCKET_REMOTE_ADDRESS_CHANGED;
+            }
+        }
+
         return EX_TYPE_FILTER_NONE;
     }
 
@@ -98,17 +120,14 @@
      * @return the local address
      */
     @Nullable
-    private InetSocketAddress getAddressFromFileDescriptor() {
+    private InetSocketAddress getLocalAddressFromFileDescriptor() {
         final ParcelFileDescriptor parcelFileDescriptor = mQosSocketInfo.getParcelFileDescriptor();
-        if (parcelFileDescriptor == null) return null;
-
         final FileDescriptor fd = parcelFileDescriptor.getFileDescriptor();
-        if (fd == null) return null;
 
         final SocketAddress address;
         try {
             address = Os.getsockname(fd);
-        } catch (final ErrnoException e) {
+        } catch (ErrnoException e) {
             Log.e(TAG, "getAddressFromFileDescriptor: getLocalAddress exception", e);
             return null;
         }
@@ -119,6 +138,31 @@
     }
 
     /**
+     * The remote address to which the socket is connected.
+     *
+     * <p>Note: If the socket is no longer connected, null is returned.
+     *
+     * @return the remote address
+     */
+    @Nullable
+    private InetSocketAddress getRemoteAddressFromFileDescriptor() {
+        final ParcelFileDescriptor parcelFileDescriptor = mQosSocketInfo.getParcelFileDescriptor();
+        final FileDescriptor fd = parcelFileDescriptor.getFileDescriptor();
+
+        final SocketAddress address;
+        try {
+            address = Os.getpeername(fd);
+        } catch (ErrnoException e) {
+            Log.e(TAG, "getAddressFromFileDescriptor: getRemoteAddress exception", e);
+            return null;
+        }
+        if (address instanceof InetSocketAddress) {
+            return (InetSocketAddress) address;
+        }
+        return null;
+    }
+
+    /**
      * The network used with this filter.
      *
      * @return the registered {@link Network}
@@ -156,6 +200,18 @@
     }
 
     /**
+     * @inheritDoc
+     */
+    @Override
+    public boolean matchesProtocol(final int protocol) {
+        if ((mQosSocketInfo.getSocketType() == SOCK_STREAM && protocol == IPPROTO_TCP)
+                || (mQosSocketInfo.getSocketType() == SOCK_DGRAM && protocol == IPPROTO_UDP)) {
+            return true;
+        }
+        return false;
+    }
+
+    /**
      * Called from {@link QosSocketFilter#matchesLocalAddress(InetAddress, int, int)}
      * and {@link QosSocketFilter#matchesRemoteAddress(InetAddress, int, int)} with the
      * filterSocketAddress coming from {@link QosSocketInfo#getLocalSocketAddress()}.
@@ -174,6 +230,7 @@
             final int startPort, final int endPort) {
         return startPort <= filterSocketAddress.getPort()
                 && endPort >= filterSocketAddress.getPort()
-                && filterSocketAddress.getAddress().equals(address);
+                && (address.isAnyLocalAddress()
+                        || filterSocketAddress.getAddress().equals(address));
     }
 }
diff --git a/framework/src/android/net/QosSocketInfo.java b/framework/src/android/net/QosSocketInfo.java
index 39c2f33..49ac22b 100644
--- a/framework/src/android/net/QosSocketInfo.java
+++ b/framework/src/android/net/QosSocketInfo.java
@@ -165,25 +165,28 @@
     /* Parcelable methods */
     private QosSocketInfo(final Parcel in) {
         mNetwork = Objects.requireNonNull(Network.CREATOR.createFromParcel(in));
-        mParcelFileDescriptor = ParcelFileDescriptor.CREATOR.createFromParcel(in);
+        final boolean withFd = in.readBoolean();
+        if (withFd) {
+            mParcelFileDescriptor = ParcelFileDescriptor.CREATOR.createFromParcel(in);
+        } else {
+            mParcelFileDescriptor = null;
+        }
 
-        final int localAddressLength = in.readInt();
-        mLocalSocketAddress = readSocketAddress(in, localAddressLength);
-
-        final int remoteAddressLength = in.readInt();
-        mRemoteSocketAddress = remoteAddressLength == 0 ? null
-                : readSocketAddress(in, remoteAddressLength);
+        mLocalSocketAddress = readSocketAddress(in);
+        mRemoteSocketAddress = readSocketAddress(in);
 
         mSocketType = in.readInt();
     }
 
-    private @NonNull InetSocketAddress readSocketAddress(final Parcel in, final int addressLength) {
-        final byte[] address = new byte[addressLength];
-        in.readByteArray(address);
+    private InetSocketAddress readSocketAddress(final Parcel in) {
+        final byte[] addrBytes = in.createByteArray();
+        if (addrBytes == null) {
+            return null;
+        }
         final int port = in.readInt();
 
         try {
-            return new InetSocketAddress(InetAddress.getByAddress(address), port);
+            return new InetSocketAddress(InetAddress.getByAddress(addrBytes), port);
         } catch (final UnknownHostException e) {
             /* This can never happen. UnknownHostException will never be thrown
                since the address provided is numeric and non-null. */
@@ -198,20 +201,35 @@
 
     @Override
     public void writeToParcel(@NonNull final Parcel dest, final int flags) {
-        mNetwork.writeToParcel(dest, 0);
-        mParcelFileDescriptor.writeToParcel(dest, 0);
+        writeToParcelInternal(dest, flags, /*includeFd=*/ true);
+    }
 
-        final byte[] localAddress = mLocalSocketAddress.getAddress().getAddress();
-        dest.writeInt(localAddress.length);
-        dest.writeByteArray(localAddress);
+    /**
+     * Used when sending QosSocketInfo to telephony, which does not need access to the socket FD.
+     * @hide
+     */
+    public void writeToParcelWithoutFd(@NonNull final Parcel dest, final int flags) {
+        writeToParcelInternal(dest, flags, /*includeFd=*/ false);
+    }
+
+    private void writeToParcelInternal(
+            @NonNull final Parcel dest, final int flags, boolean includeFd) {
+        mNetwork.writeToParcel(dest, 0);
+
+        if (includeFd) {
+            dest.writeBoolean(true);
+            mParcelFileDescriptor.writeToParcel(dest, 0);
+        } else {
+            dest.writeBoolean(false);
+        }
+
+        dest.writeByteArray(mLocalSocketAddress.getAddress().getAddress());
         dest.writeInt(mLocalSocketAddress.getPort());
 
         if (mRemoteSocketAddress == null) {
-            dest.writeInt(0);
+            dest.writeByteArray(null);
         } else {
-            final byte[] remoteAddress = mRemoteSocketAddress.getAddress().getAddress();
-            dest.writeInt(remoteAddress.length);
-            dest.writeByteArray(remoteAddress);
+            dest.writeByteArray(mRemoteSocketAddress.getAddress().getAddress());
             dest.writeInt(mRemoteSocketAddress.getPort());
         }
         dest.writeInt(mSocketType);
diff --git a/framework/src/android/net/SocketNotConnectedException.java b/framework/src/android/net/SocketNotConnectedException.java
new file mode 100644
index 0000000..fa2a615
--- /dev/null
+++ b/framework/src/android/net/SocketNotConnectedException.java
@@ -0,0 +1,29 @@
+/*
+ * Copyright (C) 2022 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package android.net;
+
+/**
+ * Thrown when a socket that is expected to be connected is not connected.
+ *
+ * @hide
+ */
+public class SocketNotConnectedException extends Exception {
+    /** @hide */
+    public SocketNotConnectedException() {
+        super("The socket is not connected");
+    }
+}
diff --git a/framework/src/android/net/SocketRemoteAddressChangedException.java b/framework/src/android/net/SocketRemoteAddressChangedException.java
new file mode 100644
index 0000000..ecaeebc
--- /dev/null
+++ b/framework/src/android/net/SocketRemoteAddressChangedException.java
@@ -0,0 +1,29 @@
+/*
+ * Copyright (C) 2022 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package android.net;
+
+/**
+ * Thrown when the remote address of the socket has changed.
+ *
+ * @hide
+ */
+public class SocketRemoteAddressChangedException extends Exception {
+    /** @hide */
+    public SocketRemoteAddressChangedException() {
+        super("The remote address of the socket changed");
+    }
+}
diff --git a/service-t/src/com/android/server/net/NetworkStatsObservers.java b/service-t/src/com/android/server/net/NetworkStatsObservers.java
index fdfc893..d974a3b 100644
--- a/service-t/src/com/android/server/net/NetworkStatsObservers.java
+++ b/service-t/src/com/android/server/net/NetworkStatsObservers.java
@@ -18,6 +18,7 @@
 
 import static android.app.usage.NetworkStatsManager.MIN_THRESHOLD_BYTES;
 
+import android.annotation.NonNull;
 import android.app.usage.NetworkStatsManager;
 import android.content.Context;
 import android.content.pm.PackageManager;
@@ -38,6 +39,7 @@
 import android.os.Process;
 import android.os.RemoteException;
 import android.util.ArrayMap;
+import android.util.IndentingPrintWriter;
 import android.util.Log;
 import android.util.SparseArray;
 
@@ -52,6 +54,7 @@
  */
 class NetworkStatsObservers {
     private static final String TAG = "NetworkStatsObservers";
+    private static final boolean LOG = true;
     private static final boolean LOGV = false;
 
     private static final int MSG_REGISTER = 1;
@@ -77,13 +80,15 @@
      *
      * @return the normalized request wrapped within {@link RequestInfo}.
      */
-    public DataUsageRequest register(Context context, DataUsageRequest inputRequest,
-            IUsageCallback callback, int callingUid, @NetworkStatsAccess.Level int accessLevel) {
+    public DataUsageRequest register(@NonNull Context context,
+            @NonNull DataUsageRequest inputRequest, @NonNull IUsageCallback callback,
+            int callingPid, int callingUid, @NonNull String callingPackage,
+            @NetworkStatsAccess.Level int accessLevel) {
         DataUsageRequest request = buildRequest(context, inputRequest, callingUid);
-        RequestInfo requestInfo = buildRequestInfo(request, callback, callingUid,
-                accessLevel);
+        RequestInfo requestInfo = buildRequestInfo(request, callback, callingPid, callingUid,
+                callingPackage, accessLevel);
 
-        if (LOGV) Log.v(TAG, "Registering observer for " + request);
+        if (LOG) Log.d(TAG, "Registering observer for " + requestInfo);
         getHandler().sendMessage(mHandler.obtainMessage(MSG_REGISTER, requestInfo));
         return request;
     }
@@ -172,7 +177,7 @@
         RequestInfo requestInfo;
         requestInfo = mDataUsageRequests.get(request.requestId);
         if (requestInfo == null) {
-            if (LOGV) Log.v(TAG, "Trying to unregister unknown request " + request);
+            if (LOG) Log.d(TAG, "Trying to unregister unknown request " + request);
             return;
         }
         if (Process.SYSTEM_UID != callingUid && requestInfo.mCallingUid != callingUid) {
@@ -180,7 +185,7 @@
             return;
         }
 
-        if (LOGV) Log.v(TAG, "Unregistering " + request);
+        if (LOG) Log.d(TAG, "Unregistering " + requestInfo);
         mDataUsageRequests.remove(request.requestId);
         requestInfo.unlinkDeathRecipient();
         requestInfo.callCallback(NetworkStatsManager.CALLBACK_RELEASED);
@@ -214,18 +219,19 @@
     }
 
     private RequestInfo buildRequestInfo(DataUsageRequest request, IUsageCallback callback,
-            int callingUid, @NetworkStatsAccess.Level int accessLevel) {
+            int callingPid, int callingUid, @NonNull String callingPackage,
+            @NetworkStatsAccess.Level int accessLevel) {
         if (accessLevel <= NetworkStatsAccess.Level.USER) {
-            return new UserUsageRequestInfo(this, request, callback, callingUid,
-                    accessLevel);
+            return new UserUsageRequestInfo(this, request, callback, callingPid,
+                    callingUid, callingPackage, accessLevel);
         } else {
             // Safety check in case a new access level is added and we forgot to update this
             if (accessLevel < NetworkStatsAccess.Level.DEVICESUMMARY) {
                 throw new IllegalArgumentException(
                         "accessLevel " + accessLevel + " is less than DEVICESUMMARY.");
             }
-            return new NetworkUsageRequestInfo(this, request, callback, callingUid,
-                    accessLevel);
+            return new NetworkUsageRequestInfo(this, request, callback, callingPid,
+                    callingUid, callingPackage, accessLevel);
         }
     }
 
@@ -237,18 +243,22 @@
         private final NetworkStatsObservers mStatsObserver;
         protected final DataUsageRequest mRequest;
         private final IUsageCallback mCallback;
+        protected final int mCallingPid;
         protected final int mCallingUid;
+        protected final String mCallingPackage;
         protected final @NetworkStatsAccess.Level int mAccessLevel;
         protected NetworkStatsRecorder mRecorder;
         protected NetworkStatsCollection mCollection;
 
         RequestInfo(NetworkStatsObservers statsObserver, DataUsageRequest request,
-                IUsageCallback callback, int callingUid,
-                    @NetworkStatsAccess.Level int accessLevel) {
+                IUsageCallback callback, int callingPid, int callingUid,
+                @NonNull String callingPackage, @NetworkStatsAccess.Level int accessLevel) {
             mStatsObserver = statsObserver;
             mRequest = request;
             mCallback = callback;
+            mCallingPid = callingPid;
             mCallingUid = callingUid;
+            mCallingPackage = callingPackage;
             mAccessLevel = accessLevel;
 
             try {
@@ -269,7 +279,8 @@
 
         @Override
         public String toString() {
-            return "RequestInfo from uid:" + mCallingUid
+            return "RequestInfo from pid/uid:" + mCallingPid + "/" + mCallingUid
+                    + "(" + mCallingPackage + ")"
                     + " for " + mRequest + " accessLevel:" + mAccessLevel;
         }
 
@@ -338,9 +349,10 @@
 
     private static class NetworkUsageRequestInfo extends RequestInfo {
         NetworkUsageRequestInfo(NetworkStatsObservers statsObserver, DataUsageRequest request,
-                IUsageCallback callback, int callingUid,
-                    @NetworkStatsAccess.Level int accessLevel) {
-            super(statsObserver, request, callback, callingUid, accessLevel);
+                IUsageCallback callback, int callingPid, int callingUid,
+                @NonNull String callingPackage, @NetworkStatsAccess.Level int accessLevel) {
+            super(statsObserver, request, callback, callingPid, callingUid, callingPackage,
+                    accessLevel);
         }
 
         @Override
@@ -380,9 +392,10 @@
 
     private static class UserUsageRequestInfo extends RequestInfo {
         UserUsageRequestInfo(NetworkStatsObservers statsObserver, DataUsageRequest request,
-                    IUsageCallback callback, int callingUid,
-                    @NetworkStatsAccess.Level int accessLevel) {
-            super(statsObserver, request, callback, callingUid, accessLevel);
+                IUsageCallback callback, int callingPid, int callingUid,
+                @NonNull String callingPackage, @NetworkStatsAccess.Level int accessLevel) {
+            super(statsObserver, request, callback, callingPid, callingUid,
+                    callingPackage, accessLevel);
         }
 
         @Override
@@ -448,4 +461,10 @@
             mCurrentTime = currentTime;
         }
     }
+
+    public void dump(IndentingPrintWriter pw) {
+        for (int i = 0; i < mDataUsageRequests.size(); i++) {
+            pw.println(mDataUsageRequests.valueAt(i));
+        }
+    }
 }
diff --git a/service-t/src/com/android/server/net/NetworkStatsService.java b/service-t/src/com/android/server/net/NetworkStatsService.java
index 82b1fb5..7c167ed 100644
--- a/service-t/src/com/android/server/net/NetworkStatsService.java
+++ b/service-t/src/com/android/server/net/NetworkStatsService.java
@@ -159,8 +159,11 @@
 import java.time.ZoneOffset;
 import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
 import java.util.HashSet;
 import java.util.List;
+import java.util.Map;
 import java.util.Objects;
 import java.util.concurrent.CopyOnWriteArrayList;
 import java.util.concurrent.Executor;
@@ -374,9 +377,9 @@
 
     private long mLastStatsSessionPoll;
 
-    /** Map from UID to number of opened sessions */
-    @GuardedBy("mOpenSessionCallsPerUid")
-    private final SparseIntArray mOpenSessionCallsPerUid = new SparseIntArray();
+    /** Map from key {@code OpenSessionKey} to count of opened sessions */
+    @GuardedBy("mOpenSessionCallsPerCaller")
+    private final HashMap<OpenSessionKey, Integer> mOpenSessionCallsPerCaller = new HashMap<>();
 
     private final static int DUMP_STATS_SESSION_COUNT = 20;
 
@@ -407,6 +410,48 @@
                 Clock.systemUTC());
     }
 
+    /**
+     * Key used in {@code mOpenSessionCallsPerCaller} to identify the caller whose opened
+     * sessions are being counted.
+     */
+    private static class OpenSessionKey {
+        public final int mPid;
+        public final int mUid;
+        public final String mPackage;
+
+        OpenSessionKey(int pid, int uid, @NonNull String packageName) {
+            mPid = pid;
+            mUid = uid;
+            mPackage = packageName;
+        }
+
+        @Override
+        public String toString() {
+            final StringBuilder sb = new StringBuilder();
+            sb.append("{");
+            sb.append("pid=").append(mPid).append(",");
+            sb.append("uid=").append(mUid).append(",");
+            sb.append("package=").append(mPackage);
+            sb.append("}");
+            return sb.toString();
+        }
+
+        @Override
+        public boolean equals(@NonNull Object o) {
+            if (this == o) return true;
+            if (o.getClass() != getClass()) return false;
+
+            final OpenSessionKey key = (OpenSessionKey) o;
+            return mPid == key.mPid && mUid == key.mUid
+                    && TextUtils.equals(mPackage, key.mPackage);
+        }
+
+        @Override
+        public int hashCode() {
+            return Objects.hash(mPid, mUid, mPackage);
+        }
+    }
+
     private final class NetworkStatsHandler extends Handler {
         NetworkStatsHandler(@NonNull Looper looper) {
             super(looper);
@@ -794,24 +839,29 @@
         return openSessionInternal(flags, callingPackage);
     }
 
-    private boolean isRateLimitedForPoll(int callingUid) {
-        if (callingUid == android.os.Process.SYSTEM_UID) {
-            return false;
-        }
-
+    private boolean isRateLimitedForPoll(@NonNull OpenSessionKey key) {
         final long lastCallTime;
         final long now = SystemClock.elapsedRealtime();
-        synchronized (mOpenSessionCallsPerUid) {
-            int calls = mOpenSessionCallsPerUid.get(callingUid, 0);
-            mOpenSessionCallsPerUid.put(callingUid, calls + 1);
+
+        synchronized (mOpenSessionCallsPerCaller) {
+            Integer calls = mOpenSessionCallsPerCaller.get(key);
+            if (calls == null) {
+                mOpenSessionCallsPerCaller.put((key), 1);
+            } else {
+                mOpenSessionCallsPerCaller.put(key, Integer.sum(calls, 1));
+            }
             lastCallTime = mLastStatsSessionPoll;
             mLastStatsSessionPoll = now;
         }
 
+        if (key.mUid == android.os.Process.SYSTEM_UID) {
+            return false;
+        }
+
         return now - lastCallTime < POLL_RATE_LIMIT_MS;
     }
 
-    private int restrictFlagsForCaller(int flags) {
+    private int restrictFlagsForCaller(int flags, @NonNull String callingPackage) {
         // All non-privileged callers are not allowed to turn off POLL_ON_OPEN.
         final boolean isPrivileged = PermissionUtils.checkAnyPermissionOf(mContext,
                 NetworkStack.PERMISSION_MAINLINE_NETWORK_STACK,
@@ -820,15 +870,17 @@
             flags |= NetworkStatsManager.FLAG_POLL_ON_OPEN;
         }
         // Non-system uids are rate limited for POLL_ON_OPEN.
+        final int callingPid = Binder.getCallingPid();
         final int callingUid = Binder.getCallingUid();
-        flags = isRateLimitedForPoll(callingUid)
+        final OpenSessionKey key = new OpenSessionKey(callingPid, callingUid, callingPackage);
+        flags = isRateLimitedForPoll(key)
                 ? flags & (~NetworkStatsManager.FLAG_POLL_ON_OPEN)
                 : flags;
         return flags;
     }
 
     private INetworkStatsSession openSessionInternal(final int flags, final String callingPackage) {
-        final int restrictedFlags = restrictFlagsForCaller(flags);
+        final int restrictedFlags = restrictFlagsForCaller(flags, callingPackage);
         if ((restrictedFlags & (NetworkStatsManager.FLAG_POLL_ON_OPEN
                 | NetworkStatsManager.FLAG_POLL_FORCE)) != 0) {
             final long ident = Binder.clearCallingIdentity();
@@ -1279,13 +1331,14 @@
         Objects.requireNonNull(request.template, "NetworkTemplate is null");
         Objects.requireNonNull(callback, "callback is null");
 
-        int callingUid = Binder.getCallingUid();
+        final int callingPid = Binder.getCallingPid();
+        final int callingUid = Binder.getCallingUid();
         @NetworkStatsAccess.Level int accessLevel = checkAccessLevel(callingPackage);
         DataUsageRequest normalizedRequest;
         final long token = Binder.clearCallingIdentity();
         try {
             normalizedRequest = mStatsObservers.register(mContext,
-                    request, callback, callingUid, accessLevel);
+                    request, callback, callingPid, callingUid, callingPackage, accessLevel);
         } finally {
             Binder.restoreCallingIdentity(token);
         }
@@ -2060,25 +2113,21 @@
             pw.decreaseIndent();
 
             // Get the top openSession callers
-            final SparseIntArray calls;
-            synchronized (mOpenSessionCallsPerUid) {
-                calls = mOpenSessionCallsPerUid.clone();
+            final HashMap calls;
+            synchronized (mOpenSessionCallsPerCaller) {
+                calls = new HashMap<>(mOpenSessionCallsPerCaller);
             }
-
-            final int N = calls.size();
-            final long[] values = new long[N];
-            for (int j = 0; j < N; j++) {
-                values[j] = ((long) calls.valueAt(j) << 32) | calls.keyAt(j);
-            }
-            Arrays.sort(values);
-
-            pw.println("Top openSession callers (uid=count):");
+            final List<Map.Entry<OpenSessionKey, Integer>> list = new ArrayList<>(calls.entrySet());
+            Collections.sort(list,
+                    (left, right) -> Integer.compare(left.getValue(), right.getValue()));
+            final int num = list.size();
+            final int end = Math.max(0, num - DUMP_STATS_SESSION_COUNT);
+            pw.println("Top openSession callers:");
             pw.increaseIndent();
-            final int end = Math.max(0, N - DUMP_STATS_SESSION_COUNT);
-            for (int j = N - 1; j >= end; j--) {
-                final int uid = (int) (values[j] & 0xffffffff);
-                final int count = (int) (values[j] >> 32);
-                pw.print(uid); pw.print("="); pw.println(count);
+            for (int j = num - 1; j >= end; j--) {
+                final Map.Entry<OpenSessionKey, Integer> entry = list.get(j);
+                pw.print(entry.getKey()); pw.print("="); pw.println(entry.getValue());
+
             }
             pw.decreaseIndent();
             pw.println();
@@ -2098,6 +2147,13 @@
                 }
             });
             pw.decreaseIndent();
+            pw.println();
+
+            pw.println("Stats Observers:");
+            pw.increaseIndent();
+            mStatsObservers.dump(pw);
+            pw.decreaseIndent();
+            pw.println();
 
             pw.println("Dev stats:");
             pw.increaseIndent();
diff --git a/service/src/com/android/server/connectivity/QosCallbackAgentConnection.java b/service/src/com/android/server/connectivity/QosCallbackAgentConnection.java
index 534dbe7..e682026 100644
--- a/service/src/com/android/server/connectivity/QosCallbackAgentConnection.java
+++ b/service/src/com/android/server/connectivity/QosCallbackAgentConnection.java
@@ -30,6 +30,8 @@
 import android.telephony.data.NrQosSessionAttributes;
 import android.util.Log;
 
+import com.android.modules.utils.build.SdkLevel;
+
 import java.util.Objects;
 
 /**
@@ -149,6 +151,7 @@
 
     void sendEventEpsQosSessionAvailable(final QosSession session,
             final EpsBearerQosSessionAttributes attributes) {
+        if (!validateOrSendErrorAndUnregister()) return;
         try {
             if (DBG) log("sendEventEpsQosSessionAvailable: sending...");
             mCallback.onQosEpsBearerSessionAvailable(session, attributes);
@@ -159,6 +162,7 @@
 
     void sendEventNrQosSessionAvailable(final QosSession session,
             final NrQosSessionAttributes attributes) {
+        if (!validateOrSendErrorAndUnregister()) return;
         try {
             if (DBG) log("sendEventNrQosSessionAvailable: sending...");
             mCallback.onNrQosSessionAvailable(session, attributes);
@@ -168,6 +172,7 @@
     }
 
     void sendEventQosSessionLost(@NonNull final QosSession session) {
+        if (!validateOrSendErrorAndUnregister()) return;
         try {
             if (DBG) log("sendEventQosSessionLost: sending...");
             mCallback.onQosSessionLost(session);
@@ -185,6 +190,21 @@
         }
     }
 
+    // Re-validates mFilter and returns true when it is still valid. On failure this logs and,
+    // from Android T onwards, sends the QoS callback error and unregisters the callback.
+    private boolean validateOrSendErrorAndUnregister() {
+        final int validationResult = mFilter.validate();
+        if (validationResult == EX_TYPE_FILTER_NONE) return true;
+        log("validation fail before sending QosCallback.");
+        // The error path is limited to Android T and later so that applications running on
+        // Android S are not disrupted.
+        if (SdkLevel.isAtLeastT()) {
+            sendEventQosCallbackError(validationResult);
+            mQosCallbackTracker.unregisterCallback(mCallback);
+        }
+        return false;
+    }
+
     private static void log(@NonNull final String msg) {
         Log.d(TAG, msg);
     }
diff --git a/tests/common/java/android/net/netstats/NetworkStatsCollectionTest.kt b/tests/common/java/android/net/netstats/NetworkStatsCollectionTest.kt
index ca0e5ed..368a519 100644
--- a/tests/common/java/android/net/netstats/NetworkStatsCollectionTest.kt
+++ b/tests/common/java/android/net/netstats/NetworkStatsCollectionTest.kt
@@ -16,7 +16,7 @@
 
 package android.net.netstats
 
-import android.net.NetworkIdentitySet
+import android.net.NetworkIdentity
 import android.net.NetworkStatsCollection
 import android.net.NetworkStatsHistory
 import androidx.test.filters.SmallTest
@@ -40,7 +40,7 @@
 
     @Test
     fun testBuilder() {
-        val ident = NetworkIdentitySet()
+        val ident = setOf<NetworkIdentity>()
         val key1 = NetworkStatsCollection.Key(ident, /* uid */ 0, /* set */ 0, /* tag */ 0)
         val key2 = NetworkStatsCollection.Key(ident, /* uid */ 1, /* set */ 0, /* tag */ 0)
         val bucketDuration = 10L
@@ -63,4 +63,4 @@
         val actualHistory = actualEntries[key1] ?: fail("There should be an entry for $key1")
         assertEquals(history1.entries, actualHistory.entries)
     }
-}
\ No newline at end of file
+}
diff --git a/tests/cts/hostside/app2/Android.bp b/tests/cts/hostside/app2/Android.bp
index 01c8cd2..edfaf9f 100644
--- a/tests/cts/hostside/app2/Android.bp
+++ b/tests/cts/hostside/app2/Android.bp
@@ -23,6 +23,7 @@
     defaults: ["cts_support_defaults"],
     sdk_version: "test_current",
     static_libs: [
+        "androidx.annotation_annotation",
         "CtsHostsideNetworkTestsAidl",
         "NetworkStackApiStableShims",
     ],
diff --git a/tests/cts/hostside/app2/src/com/android/cts/net/hostside/app2/MyActivity.java b/tests/cts/hostside/app2/src/com/android/cts/net/hostside/app2/MyActivity.java
index 08cdea7..51acfdf 100644
--- a/tests/cts/hostside/app2/src/com/android/cts/net/hostside/app2/MyActivity.java
+++ b/tests/cts/hostside/app2/src/com/android/cts/net/hostside/app2/MyActivity.java
@@ -30,6 +30,8 @@
 import android.os.RemoteException;
 import android.util.Log;
 
+import androidx.annotation.GuardedBy;
+
 import com.android.cts.net.hostside.INetworkStateObserver;
 
 /**
@@ -37,6 +39,7 @@
  */
 public class MyActivity extends Activity {
 
+    @GuardedBy("this")
     private BroadcastReceiver finishCommandReceiver = null;
 
     @Override
@@ -47,8 +50,11 @@
 
     @Override
     public void finish() {
-        if (finishCommandReceiver != null) {
-            unregisterReceiver(finishCommandReceiver);
+        synchronized (this) {  // finishCommandReceiver is @GuardedBy("this")
+            if (finishCommandReceiver != null) {
+                unregisterReceiver(finishCommandReceiver);
+                finishCommandReceiver = null;  // a later finish() then skips the unregister
+            }
         }
         super.finish();
     }
@@ -71,14 +77,17 @@
         super.onResume();
         Log.d(TAG, "MyActivity.onResume(): " + getIntent());
         Common.notifyNetworkStateObserver(this, getIntent(), TYPE_COMPONENT_ACTIVTY);
-        finishCommandReceiver = new BroadcastReceiver() {
-            @Override
-            public void onReceive(Context context, Intent intent) {
-                Log.d(TAG, "Finishing MyActivity");
-                MyActivity.this.finish();
-            }
-        };
-        registerReceiver(finishCommandReceiver, new IntentFilter(ACTION_FINISH_ACTIVITY));
+        synchronized (this) {
+            finishCommandReceiver = new BroadcastReceiver() {
+                @Override
+                public void onReceive(Context context, Intent intent) {
+                    Log.d(TAG, "Finishing MyActivity");
+                    MyActivity.this.finish();
+                }
+            };
+            registerReceiver(finishCommandReceiver, new IntentFilter(ACTION_FINISH_ACTIVITY),
+                    Context.RECEIVER_EXPORTED);
+        }
     }
 
     @Override
diff --git a/tests/unit/java/android/net/QosSocketFilterTest.java b/tests/unit/java/android/net/QosSocketFilterTest.java
index 91f2cdd..6820b40 100644
--- a/tests/unit/java/android/net/QosSocketFilterTest.java
+++ b/tests/unit/java/android/net/QosSocketFilterTest.java
@@ -16,8 +16,17 @@
 
 package android.net;
 
-import static junit.framework.Assert.assertFalse;
-import static junit.framework.Assert.assertTrue;
+import static android.net.QosCallbackException.EX_TYPE_FILTER_NONE;
+import static android.net.QosCallbackException.EX_TYPE_FILTER_SOCKET_LOCAL_ADDRESS_CHANGED;
+import static android.net.QosCallbackException.EX_TYPE_FILTER_SOCKET_NOT_BOUND;
+import static android.net.QosCallbackException.EX_TYPE_FILTER_SOCKET_NOT_CONNECTED;
+import static android.net.QosCallbackException.EX_TYPE_FILTER_SOCKET_REMOTE_ADDRESS_CHANGED;
+import static android.system.OsConstants.IPPROTO_TCP;
+import static android.system.OsConstants.IPPROTO_UDP;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
 
 import android.os.Build;
 
@@ -29,6 +38,7 @@
 import org.junit.Test;
 import org.junit.runner.RunWith;
 
+import java.net.DatagramSocket;
 import java.net.InetAddress;
 import java.net.InetSocketAddress;
 
@@ -36,14 +46,14 @@
 @SmallTest
 @DevSdkIgnoreRule.IgnoreUpTo(Build.VERSION_CODES.R)
 public class QosSocketFilterTest {
-
+    private static final int TEST_NET_ID = 1777;
+    private final Network mNetwork = new Network(TEST_NET_ID);
     @Test
     public void testPortExactMatch() {
         final InetAddress addressA = InetAddresses.parseNumericAddress("1.2.3.4");
         final InetAddress addressB = InetAddresses.parseNumericAddress("1.2.3.4");
         assertTrue(QosSocketFilter.matchesAddress(
                 new InetSocketAddress(addressA, 10), addressB, 10, 10));
-
     }
 
     @Test
@@ -77,5 +87,90 @@
         assertFalse(QosSocketFilter.matchesAddress(
                 new InetSocketAddress(addressA, 10), addressB, 10, 10));
     }
+
+    @Test
+    public void testAddressMatchWithAnyLocalAddresses() {
+        final InetAddress specific = InetAddresses.parseNumericAddress("1.2.3.4");
+        final InetAddress wildcard = InetAddresses.parseNumericAddress("0.0.0.0");
+        final InetSocketAddress specificSock = new InetSocketAddress(specific, 10);
+        final InetSocketAddress wildcardSock = new InetSocketAddress(wildcard, 10);
+        assertTrue(QosSocketFilter.matchesAddress(specificSock, wildcard, 10, 10));
+        assertFalse(QosSocketFilter.matchesAddress(wildcardSock, specific, 10, 10));
+    }
+
+    // Filters built from connected UDP sockets (one IPv4, one IPv6) must match IPPROTO_UDP
+    // and must not match IPPROTO_TCP.
+    @Test
+    public void testProtocolMatch() throws Exception {
+        final DatagramSocket udp4 = new DatagramSocket(new InetSocketAddress("127.0.0.1", 0));
+        udp4.connect(new InetSocketAddress("127.0.0.1", udp4.getLocalPort() + 10));
+        final DatagramSocket udp6 = new DatagramSocket(new InetSocketAddress("::1", 0));
+        udp6.connect(new InetSocketAddress("::1", udp6.getLocalPort() + 10));
+        final QosSocketFilter filter4 = new QosSocketFilter(new QosSocketInfo(mNetwork, udp4));
+        final QosSocketFilter filter6 = new QosSocketFilter(new QosSocketInfo(mNetwork, udp6));
+        assertTrue(filter4.matchesProtocol(IPPROTO_UDP));
+        assertTrue(filter6.matchesProtocol(IPPROTO_UDP));
+        assertFalse(filter4.matchesProtocol(IPPROTO_TCP));
+        assertFalse(filter6.matchesProtocol(IPPROTO_TCP));
+        udp4.close();
+        udp6.close();
+    }
+
+    // An IPv4 socket that is bound and connected, and an IPv6 socket that is only bound, both
+    // validate with EX_TYPE_FILTER_NONE when nothing changes after the filter is built.
+    @Test
+    public void testValidate() throws Exception {
+        final DatagramSocket udp4 = new DatagramSocket(new InetSocketAddress("127.0.0.1", 0));
+        udp4.connect(new InetSocketAddress("127.0.0.1", udp4.getLocalPort() + 7));
+        final DatagramSocket udp6 = new DatagramSocket(new InetSocketAddress("::1", 0));
+
+        final QosSocketFilter filter4 = new QosSocketFilter(new QosSocketInfo(mNetwork, udp4));
+        final QosSocketFilter filter6 = new QosSocketFilter(new QosSocketInfo(mNetwork, udp6));
+        assertEquals(EX_TYPE_FILTER_NONE, filter4.validate());
+        assertEquals(EX_TYPE_FILTER_NONE, filter6.validate());
+        udp4.close();
+        udp6.close();
+    }
+
+    // A filter built from a socket that was never bound must report SOCKET_NOT_BOUND.
+    @Test
+    public void testValidateUnbind() throws Exception {
+        final DatagramSocket unbound = new DatagramSocket(null);
+        final QosSocketInfo info = new QosSocketInfo(mNetwork, unbound);
+        final QosSocketFilter filter = new QosSocketFilter(info);
+        assertEquals(EX_TYPE_FILTER_SOCKET_NOT_BOUND, filter.validate());
+        unbound.close();
+    }
+
+    // Binding a socket after its filter was built must be reported as a local address change.
+    @Test
+    public void testValidateLocalAddressChanged() throws Exception {
+        final DatagramSocket udp4 = new DatagramSocket(null);
+        final DatagramSocket udp6 = new DatagramSocket(null);
+        final QosSocketFilter filter4 = new QosSocketFilter(new QosSocketInfo(mNetwork, udp4));
+        final QosSocketFilter filter6 = new QosSocketFilter(new QosSocketInfo(mNetwork, udp6));
+        // Bind only now, i.e. after both filters have been built from the unbound sockets.
+        udp4.bind(new InetSocketAddress("127.0.0.1", 0));
+        udp6.bind(new InetSocketAddress("::1", 0));
+        assertEquals(EX_TYPE_FILTER_SOCKET_LOCAL_ADDRESS_CHANGED, filter4.validate());
+        assertEquals(EX_TYPE_FILTER_SOCKET_LOCAL_ADDRESS_CHANGED, filter6.validate());
+        udp4.close();
+        udp6.close();
+    }
+
+    // Disconnecting, and then reconnecting to a different port, must each show up in validate().
+    @Test
+    public void testValidateRemoteAddressChanged() throws Exception {
+        // Bind to port 0 so the system picks a free port instead of relying on a fixed one.
+        try (DatagramSocket socket = new DatagramSocket(new InetSocketAddress("127.0.0.1", 0))) {
+            socket.connect(new InetSocketAddress("127.0.0.1", socket.getLocalPort() + 11));
+            QosSocketFilter filter = new QosSocketFilter(new QosSocketInfo(mNetwork, socket));
+            assertEquals(EX_TYPE_FILTER_NONE, filter.validate());
+            socket.disconnect();
+            assertEquals(EX_TYPE_FILTER_SOCKET_NOT_CONNECTED, filter.validate());
+            socket.connect(new InetSocketAddress("127.0.0.1", socket.getLocalPort() + 13));
+            assertEquals(EX_TYPE_FILTER_SOCKET_REMOTE_ADDRESS_CHANGED, filter.validate());
+        }
+    }
 }
 
diff --git a/tests/unit/java/android/net/QosSocketInfoTest.java b/tests/unit/java/android/net/QosSocketInfoTest.java
new file mode 100644
index 0000000..749c182
--- /dev/null
+++ b/tests/unit/java/android/net/QosSocketInfoTest.java
@@ -0,0 +1,111 @@
+/*
+ * Copyright (C) 2022 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package android.net;
+
+import static android.system.OsConstants.SOCK_DGRAM;
+import static android.system.OsConstants.SOCK_STREAM;
+
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+import static org.mockito.Mockito.mock;
+
+import android.os.Build;
+
+import androidx.test.filters.SmallTest;
+
+import com.android.testutils.DevSdkIgnoreRule;
+import com.android.testutils.DevSdkIgnoreRunner;
+
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.mockito.Mock;
+
+import java.net.DatagramSocket;
+import java.net.InetSocketAddress;
+import java.net.ServerSocket;
+import java.net.Socket;
+
+@RunWith(DevSdkIgnoreRunner.class)
+@SmallTest
+@DevSdkIgnoreRule.IgnoreUpTo(Build.VERSION_CODES.R)
+public class QosSocketInfoTest {
+    @Mock
+    private Network mMockNetwork = mock(Network.class);
+
+    /** A connected TCP socket is described by its two endpoints and SOCK_STREAM. */
+    @Test
+    public void testConstructWithSock() throws Exception {
+        verifyStreamSocketInfo("127.0.0.1");
+        verifyStreamSocketInfo("::1");
+    }
+
+    /** A connected UDP socket is described by its two endpoints and SOCK_DGRAM. */
+    @Test
+    public void testConstructWithDatagramSock() throws Exception {
+        verifyDatagramSocketInfo("127.0.0.1");
+        // Run the IPv6 case on the IPv6 loopback so that it is not a repeat of the IPv4 one.
+        verifyDatagramSocketInfo("::1");
+    }
+
+    /**
+     * Connects a TCP client to a listener on the given loopback address and checks the
+     * QosSocketInfo built from the client socket.
+     *
+     * @param loopback numeric loopback literal, used for both ends of the connection
+     * @throws Exception if a socket cannot be created, bound or connected
+     */
+    private void verifyStreamSocketInfo(String loopback) throws Exception {
+        final InetSocketAddress anyPort = new InetSocketAddress(loopback, 0);
+        // try-with-resources closes both sockets even when an assertion fails.
+        try (ServerSocket server = new ServerSocket()) {
+            server.bind(anyPort);
+            try (Socket socket = new Socket(anyPort.getAddress(), server.getLocalPort(),
+                    anyPort.getAddress(), anyPort.getPort())) {
+                final QosSocketInfo sockInfo = new QosSocketInfo(mMockNetwork, socket);
+                assertTrue(sockInfo.getLocalSocketAddress().equals(
+                        new InetSocketAddress(socket.getLocalAddress(), socket.getLocalPort())));
+                assertTrue(sockInfo.getRemoteSocketAddress()
+                        .equals((InetSocketAddress) socket.getRemoteSocketAddress()));
+                assertEquals(SOCK_STREAM, sockInfo.getSocketType());
+            }
+        }
+    }
+
+    /**
+     * Binds a UDP socket to the given loopback address, connects it, and checks the
+     * QosSocketInfo built from it.
+     *
+     * @param loopback numeric loopback literal, used for both bind and connect
+     * @throws Exception if the socket cannot be created, bound or connected
+     */
+    private void verifyDatagramSocketInfo(String loopback) throws Exception {
+        final InetSocketAddress anyPort = new InetSocketAddress(loopback, 0);
+        // try-with-resources closes the socket even when an assertion fails.
+        try (DatagramSocket socket = new DatagramSocket(null)) {
+            socket.setReuseAddress(true);
+            socket.bind(anyPort);
+            socket.connect(anyPort);
+            final QosSocketInfo sockInfo = new QosSocketInfo(mMockNetwork, socket);
+            assertTrue(sockInfo.getLocalSocketAddress()
+                    .equals((InetSocketAddress) socket.getLocalSocketAddress()));
+            assertTrue(sockInfo.getRemoteSocketAddress()
+                    .equals((InetSocketAddress) socket.getRemoteSocketAddress()));
+            assertEquals(SOCK_DGRAM, sockInfo.getSocketType());
+        }
+    }
+}
diff --git a/tests/unit/java/com/android/server/ConnectivityServiceTest.java b/tests/unit/java/com/android/server/ConnectivityServiceTest.java
index b8e90ff..34c8ff5 100644
--- a/tests/unit/java/com/android/server/ConnectivityServiceTest.java
+++ b/tests/unit/java/com/android/server/ConnectivityServiceTest.java
@@ -137,6 +137,8 @@
 import static com.android.server.ConnectivityService.PREFERENCE_ORDER_PROFILE;
 import static com.android.server.ConnectivityService.PREFERENCE_ORDER_VPN;
 import static com.android.server.ConnectivityServiceTestUtils.transportToLegacyType;
+import static com.android.server.NetworkAgentWrapper.CallbackType.OnQosCallbackRegister;
+import static com.android.server.NetworkAgentWrapper.CallbackType.OnQosCallbackUnregister;
 import static com.android.testutils.ConcurrentUtils.await;
 import static com.android.testutils.ConcurrentUtils.durationOf;
 import static com.android.testutils.DevSdkIgnoreRule.IgnoreAfter;
@@ -11939,16 +11941,14 @@
         mQosCallbackMockHelper.registerQosCallback(
                 mQosCallbackMockHelper.mFilter, mQosCallbackMockHelper.mCallback);
 
-        final NetworkAgentWrapper.CallbackType.OnQosCallbackRegister cbRegister1 =
-                (NetworkAgentWrapper.CallbackType.OnQosCallbackRegister)
-                        wrapper.getCallbackHistory().poll(1000, x -> true);
+        final OnQosCallbackRegister cbRegister1 =
+                (OnQosCallbackRegister) wrapper.getCallbackHistory().poll(1000, x -> true);
         assertNotNull(cbRegister1);
 
         final int registerCallbackId = cbRegister1.mQosCallbackId;
         mService.unregisterQosCallback(mQosCallbackMockHelper.mCallback);
-        final NetworkAgentWrapper.CallbackType.OnQosCallbackUnregister cbUnregister;
-        cbUnregister = (NetworkAgentWrapper.CallbackType.OnQosCallbackUnregister)
-                wrapper.getCallbackHistory().poll(1000, x -> true);
+        final OnQosCallbackUnregister cbUnregister =
+                (OnQosCallbackUnregister) wrapper.getCallbackHistory().poll(1000, x -> true);
         assertNotNull(cbUnregister);
         assertEquals(registerCallbackId, cbUnregister.mQosCallbackId);
         assertNull(wrapper.getCallbackHistory().poll(200, x -> true));
@@ -12027,6 +12027,86 @@
                         && session.getSessionType() == QosSession.TYPE_NR_BEARER));
     }
 
+    // After the filter starts failing validation, sending a "session available" event makes
+    // the agent see the callback unregistered and the callback receive onError.
+    @Test @IgnoreUpTo(SC_V2)
+    public void testQosCallbackAvailableOnValidationError() throws Exception {
+        mQosCallbackMockHelper = new QosCallbackMockHelper();
+        final NetworkAgentWrapper wrapper = mQosCallbackMockHelper.mAgentWrapper;
+        final int sessionId = 10;
+        final int qosCallbackId = 1;
+        // Register while the filter still validates and remember the id the agent was given.
+        doReturn(QosCallbackException.EX_TYPE_FILTER_NONE)
+                .when(mQosCallbackMockHelper.mFilter).validate();
+        mQosCallbackMockHelper.registerQosCallback(
+                mQosCallbackMockHelper.mFilter, mQosCallbackMockHelper.mCallback);
+        final OnQosCallbackRegister cbRegister1 =
+                (OnQosCallbackRegister) wrapper.getCallbackHistory().poll(1000, x -> true);
+        assertNotNull(cbRegister1);
+        final int registerCallbackId = cbRegister1.mQosCallbackId;
+        waitForIdle();
+
+        // Make validation fail, then have the agent report an available session.
+        doReturn(QosCallbackException.EX_TYPE_FILTER_SOCKET_REMOTE_ADDRESS_CHANGED)
+                .when(mQosCallbackMockHelper.mFilter).validate();
+        final EpsBearerQosSessionAttributes attributes = new EpsBearerQosSessionAttributes(
+                1, 2, 3, 4, 5, new ArrayList<>());
+        wrapper.getNetworkAgent().sendQosSessionAvailable(qosCallbackId, sessionId, attributes);
+        waitForIdle();
+
+        final OnQosCallbackUnregister cbUnregister =
+                (OnQosCallbackUnregister) wrapper.getCallbackHistory().poll(1000, x -> true);
+        assertNotNull(cbUnregister);
+        assertEquals(registerCallbackId, cbUnregister.mQosCallbackId);
+        waitForIdle();
+        verify(mQosCallbackMockHelper.mCallback)
+                .onError(eq(QosCallbackException.EX_TYPE_FILTER_SOCKET_REMOTE_ADDRESS_CHANGED));
+    }
+
+    // If the filter no longer validates when a session is lost, the callback gets onError.
+    @Test @IgnoreUpTo(SC_V2)
+    public void testQosCallbackLostOnValidationError() throws Exception {
+        mQosCallbackMockHelper = new QosCallbackMockHelper();
+        final int expectedSessionId = 10;
+        final int callbackId = 1;
+        // Phase 1: the filter validates, and the "available" event reaches the callback.
+        doReturn(QosCallbackException.EX_TYPE_FILTER_NONE)
+                .when(mQosCallbackMockHelper.mFilter).validate();
+        mQosCallbackMockHelper.registerQosCallback(
+                mQosCallbackMockHelper.mFilter, mQosCallbackMockHelper.mCallback);
+        waitForIdle();
+        final EpsBearerQosSessionAttributes sentAttributes =
+                sendQosSessionEvent(callbackId, expectedSessionId, true);
+        waitForIdle();
+        verify(mQosCallbackMockHelper.mCallback).onQosEpsBearerSessionAvailable(argThat(qos ->
+                qos.getSessionType() == QosSession.TYPE_EPS_BEARER
+                        && qos.getSessionId() == expectedSessionId), eq(sentAttributes));
+
+        // Phase 2: validation now fails; sending the "lost" event must produce onError.
+        doReturn(QosCallbackException.EX_TYPE_FILTER_SOCKET_REMOTE_ADDRESS_CHANGED)
+                .when(mQosCallbackMockHelper.mFilter).validate();
+        sendQosSessionEvent(callbackId, expectedSessionId, false);
+        waitForIdle();
+        verify(mQosCallbackMockHelper.mCallback)
+                .onError(eq(QosCallbackException.EX_TYPE_FILTER_SOCKET_REMOTE_ADDRESS_CHANGED));
+    }
+
+    // Sends a QoS session event through the mock helper's network agent: "available" (with
+    // freshly built EPS bearer attributes, which are returned) or "lost" (returns null).
+    private EpsBearerQosSessionAttributes sendQosSessionEvent(
+            int qosCallbackId, int sessionId, boolean available) {
+        if (!available) {
+            mQosCallbackMockHelper.mAgentWrapper.getNetworkAgent()
+                    .sendQosSessionLost(qosCallbackId, sessionId, QosSession.TYPE_EPS_BEARER);
+            return null;
+        }
+        final EpsBearerQosSessionAttributes epsAttributes =
+                new EpsBearerQosSessionAttributes(1, 2, 3, 4, 5, new ArrayList<>());
+        mQosCallbackMockHelper.mAgentWrapper.getNetworkAgent()
+                .sendQosSessionAvailable(qosCallbackId, sessionId, epsAttributes);
+        return epsAttributes;
+    }
+
     @Test
     public void testQosCallbackTooManyRequests() throws Exception {
         mQosCallbackMockHelper = new QosCallbackMockHelper();
diff --git a/tests/unit/java/com/android/server/net/NetworkStatsObserversTest.java b/tests/unit/java/com/android/server/net/NetworkStatsObserversTest.java
index 5f9d1ff..13a85e8 100644
--- a/tests/unit/java/com/android/server/net/NetworkStatsObserversTest.java
+++ b/tests/unit/java/com/android/server/net/NetworkStatsObserversTest.java
@@ -86,11 +86,19 @@
     private static NetworkTemplate sTemplateImsi1 = buildTemplateMobileAll(IMSI_1);
     private static NetworkTemplate sTemplateImsi2 = buildTemplateMobileAll(IMSI_2);
 
+    private static final int PID_SYSTEM = 1234;
+    private static final int PID_RED = 1235;
+    private static final int PID_BLUE = 1236;
+
     private static final int UID_RED = UserHandle.PER_USER_RANGE + 1;
     private static final int UID_BLUE = UserHandle.PER_USER_RANGE + 2;
     private static final int UID_GREEN = UserHandle.PER_USER_RANGE + 3;
     private static final int UID_ANOTHER_USER = 2 * UserHandle.PER_USER_RANGE + 4;
 
+    private static final String PACKAGE_SYSTEM = "android";
+    private static final String PACKAGE_RED = "RED";
+    private static final String PACKAGE_BLUE = "BLUE";
+
     private static final long WAIT_TIMEOUT_MS = 500;
     private static final long THRESHOLD_BYTES = 2 * MB_IN_BYTES;
     private static final long BASE_BYTES = 7 * MB_IN_BYTES;
@@ -131,14 +139,15 @@
                 DataUsageRequest.REQUEST_ID_UNSET, sTemplateWifi, thresholdTooLowBytes);
 
         final DataUsageRequest requestByApp = mStatsObservers.register(mContext, inputRequest,
-                mUsageCallback, UID_RED, NetworkStatsAccess.Level.DEVICE);
+                mUsageCallback, PID_RED , UID_RED, PACKAGE_RED, NetworkStatsAccess.Level.DEVICE);
         assertTrue(requestByApp.requestId > 0);
         assertTrue(Objects.equals(sTemplateWifi, requestByApp.template));
         assertEquals(thresholdTooLowBytes, requestByApp.thresholdInBytes);
 
         // Verify the threshold requested by system uid won't be overridden.
         final DataUsageRequest requestBySystem = mStatsObservers.register(mContext, inputRequest,
-                mUsageCallback, Process.SYSTEM_UID, NetworkStatsAccess.Level.DEVICE);
+                mUsageCallback, PID_SYSTEM, Process.SYSTEM_UID, PACKAGE_SYSTEM,
+                NetworkStatsAccess.Level.DEVICE);
         assertTrue(requestBySystem.requestId > 0);
         assertTrue(Objects.equals(sTemplateWifi, requestBySystem.template));
         assertEquals(1, requestBySystem.thresholdInBytes);
@@ -151,7 +160,7 @@
                 DataUsageRequest.REQUEST_ID_UNSET, sTemplateWifi, highThresholdBytes);
 
         DataUsageRequest request = mStatsObservers.register(mContext, inputRequest, mUsageCallback,
-                Process.SYSTEM_UID, NetworkStatsAccess.Level.DEVICE);
+                PID_SYSTEM, Process.SYSTEM_UID, PACKAGE_SYSTEM, NetworkStatsAccess.Level.DEVICE);
         assertTrue(request.requestId > 0);
         assertTrue(Objects.equals(sTemplateWifi, request.template));
         assertEquals(highThresholdBytes, request.thresholdInBytes);
@@ -163,13 +172,13 @@
                 DataUsageRequest.REQUEST_ID_UNSET, sTemplateWifi, THRESHOLD_BYTES);
 
         DataUsageRequest request1 = mStatsObservers.register(mContext, inputRequest, mUsageCallback,
-                Process.SYSTEM_UID, NetworkStatsAccess.Level.DEVICE);
+                PID_SYSTEM, Process.SYSTEM_UID, PACKAGE_SYSTEM, NetworkStatsAccess.Level.DEVICE);
         assertTrue(request1.requestId > 0);
         assertTrue(Objects.equals(sTemplateWifi, request1.template));
         assertEquals(THRESHOLD_BYTES, request1.thresholdInBytes);
 
         DataUsageRequest request2 = mStatsObservers.register(mContext, inputRequest, mUsageCallback,
-                Process.SYSTEM_UID, NetworkStatsAccess.Level.DEVICE);
+                PID_SYSTEM, Process.SYSTEM_UID, PACKAGE_SYSTEM, NetworkStatsAccess.Level.DEVICE);
         assertTrue(request2.requestId > request1.requestId);
         assertTrue(Objects.equals(sTemplateWifi, request2.template));
         assertEquals(THRESHOLD_BYTES, request2.thresholdInBytes);
@@ -189,7 +198,7 @@
                 DataUsageRequest.REQUEST_ID_UNSET, sTemplateImsi1, THRESHOLD_BYTES);
 
         DataUsageRequest request = mStatsObservers.register(mContext, inputRequest, mUsageCallback,
-                Process.SYSTEM_UID, NetworkStatsAccess.Level.DEVICE);
+                PID_SYSTEM, Process.SYSTEM_UID, PACKAGE_SYSTEM, NetworkStatsAccess.Level.DEVICE);
         assertTrue(request.requestId > 0);
         assertTrue(Objects.equals(sTemplateImsi1, request.template));
         assertEquals(THRESHOLD_BYTES, request.thresholdInBytes);
@@ -209,7 +218,7 @@
                 DataUsageRequest.REQUEST_ID_UNSET, sTemplateImsi1, THRESHOLD_BYTES);
 
         DataUsageRequest request = mStatsObservers.register(mContext, inputRequest, mUsageCallback,
-                UID_RED, NetworkStatsAccess.Level.DEVICE);
+                PID_RED, UID_RED, PACKAGE_RED, NetworkStatsAccess.Level.DEVICE);
         assertTrue(request.requestId > 0);
         assertTrue(Objects.equals(sTemplateImsi1, request.template));
         assertEquals(THRESHOLD_BYTES, request.thresholdInBytes);
@@ -237,7 +246,7 @@
                 DataUsageRequest.REQUEST_ID_UNSET, sTemplateImsi1, THRESHOLD_BYTES);
 
         DataUsageRequest request = mStatsObservers.register(mContext, inputRequest, mUsageCallback,
-                Process.SYSTEM_UID, NetworkStatsAccess.Level.DEVICE);
+                PID_SYSTEM, Process.SYSTEM_UID, PACKAGE_SYSTEM, NetworkStatsAccess.Level.DEVICE);
         assertTrue(request.requestId > 0);
         assertTrue(Objects.equals(sTemplateImsi1, request.template));
         assertEquals(THRESHOLD_BYTES, request.thresholdInBytes);
@@ -261,7 +270,7 @@
                 DataUsageRequest.REQUEST_ID_UNSET, sTemplateImsi1, THRESHOLD_BYTES);
 
         DataUsageRequest request = mStatsObservers.register(mContext, inputRequest, mUsageCallback,
-                Process.SYSTEM_UID, NetworkStatsAccess.Level.DEVICE);
+                PID_SYSTEM, Process.SYSTEM_UID, PACKAGE_SYSTEM, NetworkStatsAccess.Level.DEVICE);
         assertTrue(request.requestId > 0);
         assertTrue(Objects.equals(sTemplateImsi1, request.template));
         assertEquals(THRESHOLD_BYTES, request.thresholdInBytes);
@@ -291,7 +300,7 @@
                 DataUsageRequest.REQUEST_ID_UNSET, sTemplateImsi1, THRESHOLD_BYTES);
 
         DataUsageRequest request = mStatsObservers.register(mContext, inputRequest, mUsageCallback,
-                Process.SYSTEM_UID, NetworkStatsAccess.Level.DEVICE);
+                PID_SYSTEM, Process.SYSTEM_UID, PACKAGE_SYSTEM, NetworkStatsAccess.Level.DEVICE);
         assertTrue(request.requestId > 0);
         assertTrue(Objects.equals(sTemplateImsi1, request.template));
         assertEquals(THRESHOLD_BYTES, request.thresholdInBytes);
@@ -322,7 +331,7 @@
                 DataUsageRequest.REQUEST_ID_UNSET, sTemplateImsi1, THRESHOLD_BYTES);
 
         DataUsageRequest request = mStatsObservers.register(mContext, inputRequest, mUsageCallback,
-                UID_RED, NetworkStatsAccess.Level.DEFAULT);
+                PID_RED, UID_RED, PACKAGE_SYSTEM , NetworkStatsAccess.Level.DEFAULT);
         assertTrue(request.requestId > 0);
         assertTrue(Objects.equals(sTemplateImsi1, request.template));
         assertEquals(THRESHOLD_BYTES, request.thresholdInBytes);
@@ -355,7 +364,7 @@
                 DataUsageRequest.REQUEST_ID_UNSET, sTemplateImsi1, THRESHOLD_BYTES);
 
         DataUsageRequest request = mStatsObservers.register(mContext, inputRequest, mUsageCallback,
-                UID_BLUE, NetworkStatsAccess.Level.DEFAULT);
+                PID_BLUE, UID_BLUE, PACKAGE_BLUE, NetworkStatsAccess.Level.DEFAULT);
         assertTrue(request.requestId > 0);
         assertTrue(Objects.equals(sTemplateImsi1, request.template));
         assertEquals(THRESHOLD_BYTES, request.thresholdInBytes);
@@ -387,7 +396,7 @@
                 DataUsageRequest.REQUEST_ID_UNSET, sTemplateImsi1, THRESHOLD_BYTES);
 
         DataUsageRequest request = mStatsObservers.register(mContext, inputRequest, mUsageCallback,
-                UID_BLUE, NetworkStatsAccess.Level.USER);
+                PID_BLUE, UID_BLUE, PACKAGE_BLUE, NetworkStatsAccess.Level.USER);
         assertTrue(request.requestId > 0);
         assertTrue(Objects.equals(sTemplateImsi1, request.template));
         assertEquals(THRESHOLD_BYTES, request.thresholdInBytes);
@@ -420,7 +429,7 @@
                 DataUsageRequest.REQUEST_ID_UNSET, sTemplateImsi1, THRESHOLD_BYTES);
 
         DataUsageRequest request = mStatsObservers.register(mContext, inputRequest, mUsageCallback,
-                UID_RED, NetworkStatsAccess.Level.USER);
+                PID_RED, UID_RED, PACKAGE_RED, NetworkStatsAccess.Level.USER);
         assertTrue(request.requestId > 0);
         assertTrue(Objects.equals(sTemplateImsi1, request.template));
         assertEquals(THRESHOLD_BYTES, request.thresholdInBytes);