Only send callbacks when overridden

Most users of NetworkCallback do not override all methods, so callbacks
from ConnectivityService end up being unused by the app.

To save resources, check which callback methods are overridden in
ConnectivityManager, and inform ConnectivityService so it can skip
preparing and sending the corresponding binder calls when the callback
is not used by the app.

This is flagged off by default.

Test: flag enabled: atest CtsNetTestCases ConnectivityCoverageTests
Bug: 327038794
Bug: 279392981

Change-Id: Iaf490f416fb6e1437d40497fba1a40a899a9f523
diff --git a/framework/src/android/net/ConnectivityManager.java b/framework/src/android/net/ConnectivityManager.java
index 4eaf973..ffaf41f 100644
--- a/framework/src/android/net/ConnectivityManager.java
+++ b/framework/src/android/net/ConnectivityManager.java
@@ -71,23 +71,29 @@
 import android.telephony.TelephonyManager;
 import android.util.ArrayMap;
 import android.util.Log;
+import android.util.LruCache;
 import android.util.Range;
 import android.util.SparseIntArray;
 
 import com.android.internal.annotations.GuardedBy;
+import com.android.internal.annotations.VisibleForTesting;
 import com.android.modules.utils.build.SdkLevel;
 
 import libcore.net.event.NetworkEventDispatcher;
 
 import java.io.IOException;
 import java.io.UncheckedIOException;
+import java.lang.annotation.ElementType;
 import java.lang.annotation.Retention;
 import java.lang.annotation.RetentionPolicy;
+import java.lang.annotation.Target;
+import java.lang.reflect.Method;
 import java.net.DatagramSocket;
 import java.net.InetAddress;
 import java.net.InetSocketAddress;
 import java.net.Socket;
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.Collection;
 import java.util.HashMap;
 import java.util.List;
@@ -1231,6 +1237,19 @@
     @GuardedBy("mTetheringEventCallbacks")
     private TetheringManager mTetheringManager;
 
+    // Cache of the most recently used NetworkCallback classes (not instances) -> method flags.
+    // 100 is chosen kind arbitrarily as an unlikely number of different types of NetworkCallback
+    // overrides that a process may have, and should generally not be reached (for example, the
+    // system server services.jar has been observed with dexdump to have only 16 when this was
+    // added, and a very large system services app only had 18).
+    // If this number is exceeded, the code will still function correctly, but re-registering
+    // using a network callback class that was used before, but 100+ other classes have been used in
+    // the meantime, will be a bit slower (as slow as the first registration) because
+    // getDeclaredMethodsFlag must re-examine the callback class to determine what methods it
+    // overrides.
+    private static final LruCache<Class<? extends NetworkCallback>, Integer> sMethodFlagsCache =
+            new LruCache<>(100);
+
     private final Object mEnabledConnectivityManagerFeaturesLock = new Object();
     // mEnabledConnectivityManagerFeatures is lazy-loaded in this ConnectivityManager instance, but
     // fetched from ConnectivityService, where it is loaded in ConnectivityService startup, so it
@@ -3996,6 +4015,55 @@
      */
     public static class NetworkCallback {
         /**
+         * Bitmask of method flags with all flags set.
+         * @hide
+         */
+        public static final int DECLARED_METHODS_ALL = ~0;
+
+        /**
+         * Bitmask of method flags with no flag set.
+         * @hide
+         */
+        public static final int DECLARED_METHODS_NONE = 0;
+
+        // Tracks whether an instance was created via reflection without calling the constructor.
+        private final boolean mConstructorWasCalled;
+
+        /**
+         * Annotation for NetworkCallback methods to verify filtering is configured properly.
+         *
+         * This is only used in tests to ensure that tests fail when a new callback is added, or
+         * callbacks are modified, without updating
+         * {@link NetworkCallbackMethodsHolder#NETWORK_CB_METHODS} properly.
+         * @hide
+         */
+        @Retention(RetentionPolicy.RUNTIME)
+        @Target(ElementType.METHOD)
+        @VisibleForTesting
+        public @interface FilteredCallback {
+            /**
+             * The NetworkCallback.METHOD_* ID of this method.
+             */
+            int methodId();
+
+            /**
+             * The ConnectivityManager.CALLBACK_* message that this method is directly called by.
+             *
+             * If this method is not called by any message, this should be
+             * {@link #CALLBACK_TRANSITIVE_CALLS_ONLY}.
+             */
+            int calledByCallbackId();
+
+            /**
+             * If this method may call other NetworkCallback methods, an array of methods it calls.
+             *
+             * Only direct calls (not transitive calls) should be included. The IDs must be
+             * NetworkCallback.METHOD_* IDs.
+             */
+            int[] mayCall() default {};
+        }
+
+        /**
          * No flags associated with this callback.
          * @hide
          */
@@ -4058,6 +4126,7 @@
                 throw new IllegalArgumentException("Invalid flags");
             }
             mFlags = flags;
+            mConstructorWasCalled = true;
         }
 
         /**
@@ -4075,7 +4144,9 @@
          *
          * @hide
          */
+        @FilteredCallback(methodId = METHOD_ONPRECHECK, calledByCallbackId = CALLBACK_PRECHECK)
         public void onPreCheck(@NonNull Network network) {}
+        private static final int METHOD_ONPRECHECK = 1;
 
         /**
          * Called when the framework connects and has declared a new network ready for use.
@@ -4090,6 +4161,11 @@
          * @param blocked Whether access to the {@link Network} is blocked due to system policy.
          * @hide
          */
+        @FilteredCallback(methodId = METHOD_ONAVAILABLE_5ARGS,
+                calledByCallbackId = CALLBACK_AVAILABLE,
+                mayCall = { METHOD_ONAVAILABLE_4ARGS,
+                        METHOD_ONLOCALNETWORKINFOCHANGED,
+                        METHOD_ONBLOCKEDSTATUSCHANGED_INT })
         public final void onAvailable(@NonNull Network network,
                 @NonNull NetworkCapabilities networkCapabilities,
                 @NonNull LinkProperties linkProperties,
@@ -4102,6 +4178,7 @@
             if (null != localInfo) onLocalNetworkInfoChanged(network, localInfo);
             onBlockedStatusChanged(network, blocked);
         }
+        private static final int METHOD_ONAVAILABLE_5ARGS = 2;
 
         /**
          * Legacy variant of onAvailable that takes a boolean blocked reason.
@@ -4114,6 +4191,13 @@
          *
          * @hide
          */
+        @FilteredCallback(methodId = METHOD_ONAVAILABLE_4ARGS,
+                calledByCallbackId = CALLBACK_TRANSITIVE_CALLS_ONLY,
+                mayCall = { METHOD_ONAVAILABLE_1ARG,
+                        METHOD_ONNETWORKSUSPENDED,
+                        METHOD_ONCAPABILITIESCHANGED,
+                        METHOD_ONLINKPROPERTIESCHANGED
+                })
         public void onAvailable(@NonNull Network network,
                 @NonNull NetworkCapabilities networkCapabilities,
                 @NonNull LinkProperties linkProperties,
@@ -4127,6 +4211,7 @@
             onLinkPropertiesChanged(network, linkProperties);
             // No call to onBlockedStatusChanged here. That is done by the caller.
         }
+        private static final int METHOD_ONAVAILABLE_4ARGS = 3;
 
         /**
          * Called when the framework connects and has declared a new network ready for use.
@@ -4157,7 +4242,10 @@
          *
          * @param network The {@link Network} of the satisfying network.
          */
+        @FilteredCallback(methodId = METHOD_ONAVAILABLE_1ARG,
+                calledByCallbackId = CALLBACK_TRANSITIVE_CALLS_ONLY)
         public void onAvailable(@NonNull Network network) {}
+        private static final int METHOD_ONAVAILABLE_1ARG = 4;
 
         /**
          * Called when the network is about to be lost, typically because there are no outstanding
@@ -4176,7 +4264,9 @@
          *                    connected for graceful handover; note that the network may still
          *                    suffer a hard loss at any time.
          */
+        @FilteredCallback(methodId = METHOD_ONLOSING, calledByCallbackId = CALLBACK_LOSING)
         public void onLosing(@NonNull Network network, int maxMsToLive) {}
+        private static final int METHOD_ONLOSING = 5;
 
         /**
          * Called when a network disconnects or otherwise no longer satisfies this request or
@@ -4197,7 +4287,9 @@
          *
          * @param network The {@link Network} lost.
          */
+        @FilteredCallback(methodId = METHOD_ONLOST, calledByCallbackId = CALLBACK_LOST)
         public void onLost(@NonNull Network network) {}
+        private static final int METHOD_ONLOST = 6;
 
         /**
          * Called if no network is found within the timeout time specified in
@@ -4207,7 +4299,9 @@
          * {@link NetworkRequest} will have already been removed and released, as if
          * {@link #unregisterNetworkCallback(NetworkCallback)} had been called.
          */
+        @FilteredCallback(methodId = METHOD_ONUNAVAILABLE, calledByCallbackId = CALLBACK_UNAVAIL)
         public void onUnavailable() {}
+        private static final int METHOD_ONUNAVAILABLE = 7;
 
         /**
          * Called when the network corresponding to this request changes capabilities but still
@@ -4224,8 +4318,11 @@
          * @param networkCapabilities The new {@link NetworkCapabilities} for this
          *                            network.
          */
+        @FilteredCallback(methodId = METHOD_ONCAPABILITIESCHANGED,
+                calledByCallbackId = CALLBACK_CAP_CHANGED)
         public void onCapabilitiesChanged(@NonNull Network network,
                 @NonNull NetworkCapabilities networkCapabilities) {}
+        private static final int METHOD_ONCAPABILITIESCHANGED = 8;
 
         /**
          * Called when the network corresponding to this request changes {@link LinkProperties}.
@@ -4240,8 +4337,11 @@
          * @param network The {@link Network} whose link properties have changed.
          * @param linkProperties The new {@link LinkProperties} for this network.
          */
+        @FilteredCallback(methodId = METHOD_ONLINKPROPERTIESCHANGED,
+                calledByCallbackId = CALLBACK_IP_CHANGED)
         public void onLinkPropertiesChanged(@NonNull Network network,
                 @NonNull LinkProperties linkProperties) {}
+        private static final int METHOD_ONLINKPROPERTIESCHANGED = 9;
 
         /**
          * Called when there is a change in the {@link LocalNetworkInfo} for this network.
@@ -4253,8 +4353,11 @@
          * @param localNetworkInfo the new {@link LocalNetworkInfo} for this network.
          * @hide
          */
+        @FilteredCallback(methodId = METHOD_ONLOCALNETWORKINFOCHANGED,
+                calledByCallbackId = CALLBACK_LOCAL_NETWORK_INFO_CHANGED)
         public void onLocalNetworkInfoChanged(@NonNull Network network,
                 @NonNull LocalNetworkInfo localNetworkInfo) {}
+        private static final int METHOD_ONLOCALNETWORKINFOCHANGED = 10;
 
         /**
          * Called when the network the framework connected to for this request suspends data
@@ -4273,7 +4376,10 @@
          *
          * @hide
          */
+        @FilteredCallback(methodId = METHOD_ONNETWORKSUSPENDED,
+                calledByCallbackId = CALLBACK_SUSPENDED)
         public void onNetworkSuspended(@NonNull Network network) {}
+        private static final int METHOD_ONNETWORKSUSPENDED = 11;
 
         /**
          * Called when the network the framework connected to for this request
@@ -4287,7 +4393,9 @@
          *
          * @hide
          */
+        @FilteredCallback(methodId = METHOD_ONNETWORKRESUMED, calledByCallbackId = CALLBACK_RESUMED)
         public void onNetworkResumed(@NonNull Network network) {}
+        private static final int METHOD_ONNETWORKRESUMED = 12;
 
         /**
          * Called when access to the specified network is blocked or unblocked.
@@ -4300,7 +4408,10 @@
          * @param network The {@link Network} whose blocked status has changed.
          * @param blocked The blocked status of this {@link Network}.
          */
+        @FilteredCallback(methodId = METHOD_ONBLOCKEDSTATUSCHANGED_BOOL,
+                calledByCallbackId = CALLBACK_TRANSITIVE_CALLS_ONLY)
         public void onBlockedStatusChanged(@NonNull Network network, boolean blocked) {}
+        private static final int METHOD_ONBLOCKEDSTATUSCHANGED_BOOL = 13;
 
         /**
          * Called when access to the specified network is blocked or unblocked, or the reason for
@@ -4318,10 +4429,14 @@
          * @param blocked The blocked status of this {@link Network}.
          * @hide
          */
+        @FilteredCallback(methodId = METHOD_ONBLOCKEDSTATUSCHANGED_INT,
+                calledByCallbackId = CALLBACK_BLK_CHANGED,
+                mayCall = { METHOD_ONBLOCKEDSTATUSCHANGED_BOOL })
         @SystemApi(client = MODULE_LIBRARIES)
         public void onBlockedStatusChanged(@NonNull Network network, @BlockedReason int blocked) {
             onBlockedStatusChanged(network, blocked != 0);
         }
+        private static final int METHOD_ONBLOCKEDSTATUSCHANGED_INT = 14;
 
         private NetworkRequest networkRequest;
         private final int mFlags;
@@ -4349,6 +4464,7 @@
         }
     }
 
+    private static final int CALLBACK_TRANSITIVE_CALLS_ONLY     = 0;
     /** @hide */
     public static final int CALLBACK_PRECHECK                   = 1;
     /** @hide */
@@ -4374,9 +4490,11 @@
     /** @hide */
     public static final int CALLBACK_LOCAL_NETWORK_INFO_CHANGED = 12;
 
+
     /** @hide */
     public static String getCallbackName(int whichCallback) {
         switch (whichCallback) {
+            case CALLBACK_TRANSITIVE_CALLS_ONLY: return "CALLBACK_TRANSITIVE_CALLS_ONLY";
             case CALLBACK_PRECHECK:     return "CALLBACK_PRECHECK";
             case CALLBACK_AVAILABLE:    return "CALLBACK_AVAILABLE";
             case CALLBACK_LOSING:       return "CALLBACK_LOSING";
@@ -4394,6 +4512,68 @@
         }
     }
 
+    /** @hide */
+    @VisibleForTesting
+    public static class NetworkCallbackMethod {
+        @NonNull
+        public final String mName;
+        @NonNull
+        public final Class<?>[] mParameterTypes;
+        // Bitmask of CALLBACK_* that may transitively call this method.
+        public final int mCallbacksCallingThisMethod;
+
+        public NetworkCallbackMethod(@NonNull String name, @NonNull Class<?>[] parameterTypes,
+                int callbacksCallingThisMethod) {
+            mName = name;
+            mParameterTypes = parameterTypes;
+            mCallbacksCallingThisMethod = callbacksCallingThisMethod;
+        }
+    }
+
+    // Holder class for the list of NetworkCallbackMethod. This ensures the list is only created
+    // once on first usage, and not just on ConnectivityManager class initialization.
+    /** @hide */
+    @VisibleForTesting
+    public static class NetworkCallbackMethodsHolder {
+        public static final NetworkCallbackMethod[] NETWORK_CB_METHODS =
+                new NetworkCallbackMethod[] {
+                        method("onPreCheck", 1 << CALLBACK_PRECHECK, Network.class),
+                        // Note the final overload of onAvailable is not included, since it cannot
+                        // match any overridden method.
+                        method("onAvailable", 1 << CALLBACK_AVAILABLE, Network.class),
+                        method("onAvailable", 1 << CALLBACK_AVAILABLE,
+                                Network.class, NetworkCapabilities.class,
+                                LinkProperties.class, boolean.class),
+                        method("onLosing", 1 << CALLBACK_LOSING, Network.class, int.class),
+                        method("onLost", 1 << CALLBACK_LOST, Network.class),
+                        method("onUnavailable", 1 << CALLBACK_UNAVAIL),
+                        method("onCapabilitiesChanged",
+                                1 << CALLBACK_CAP_CHANGED | 1 << CALLBACK_AVAILABLE,
+                                Network.class, NetworkCapabilities.class),
+                        method("onLinkPropertiesChanged",
+                                1 << CALLBACK_IP_CHANGED | 1 << CALLBACK_AVAILABLE,
+                                Network.class, LinkProperties.class),
+                        method("onLocalNetworkInfoChanged",
+                                1 << CALLBACK_LOCAL_NETWORK_INFO_CHANGED | 1 << CALLBACK_AVAILABLE,
+                                Network.class, LocalNetworkInfo.class),
+                        method("onNetworkSuspended",
+                                1 << CALLBACK_SUSPENDED | 1 << CALLBACK_AVAILABLE, Network.class),
+                        method("onNetworkResumed",
+                                1 << CALLBACK_RESUMED, Network.class),
+                        method("onBlockedStatusChanged",
+                                1 << CALLBACK_BLK_CHANGED | 1 << CALLBACK_AVAILABLE,
+                                Network.class, boolean.class),
+                        method("onBlockedStatusChanged",
+                                1 << CALLBACK_BLK_CHANGED | 1 << CALLBACK_AVAILABLE,
+                                Network.class, int.class),
+                };
+
+        private static NetworkCallbackMethod method(
+                String name, int callbacksCallingThisMethod, Class<?>... args) {
+            return new NetworkCallbackMethod(name, args, callbacksCallingThisMethod);
+        }
+    }
+
     private static class CallbackHandler extends Handler {
         private static final String TAG = "ConnectivityManager.CallbackHandler";
         private static final boolean DBG = false;
@@ -4513,6 +4693,14 @@
         if (reqType != TRACK_DEFAULT && reqType != TRACK_SYSTEM_DEFAULT && need == null) {
             throw new IllegalArgumentException("null NetworkCapabilities");
         }
+
+
+        final boolean useDeclaredMethods = isFeatureEnabled(
+                FEATURE_USE_DECLARED_METHODS_FOR_CALLBACKS);
+        // Set all bits if the feature is disabled
+        int declaredMethodsFlag = useDeclaredMethods
+                ? tryGetDeclaredMethodsFlag(callback)
+                : NetworkCallback.DECLARED_METHODS_ALL;
         final NetworkRequest request;
         final String callingPackageName = mContext.getOpPackageName();
         try {
@@ -4529,11 +4717,12 @@
                 if (reqType == LISTEN) {
                     request = mService.listenForNetwork(
                             need, messenger, binder, callbackFlags, callingPackageName,
-                            getAttributionTag());
+                            getAttributionTag(), declaredMethodsFlag);
                 } else {
                     request = mService.requestNetwork(
                             asUid, need, reqType.ordinal(), messenger, timeoutMs, binder,
-                            legacyType, callbackFlags, callingPackageName, getAttributionTag());
+                            legacyType, callbackFlags, callingPackageName, getAttributionTag(),
+                            declaredMethodsFlag);
                 }
                 if (request != null) {
                     sCallbacks.put(request, callback);
@@ -4548,6 +4737,108 @@
         return request;
     }
 
+    private int tryGetDeclaredMethodsFlag(@NonNull NetworkCallback cb) {
+        if (!cb.mConstructorWasCalled) {
+            // Do not use the optimization if the callback was created via reflection or mocking,
+            // as for example with dexmaker-mockito-inline methods will be instrumented without
+            // using subclasses. This does not catch all cases as it is still possible to call the
+            // constructor when creating mocks, but by default constructors are not called in that
+            // case.
+            return NetworkCallback.DECLARED_METHODS_ALL;
+        }
+        try {
+            return getDeclaredMethodsFlag(cb.getClass());
+        } catch (LinkageError e) {
+            // This may happen if some methods reference inaccessible classes in their arguments
+            // (for example b/261807130).
+            Log.w(TAG, "Could not get methods from NetworkCallback class", e);
+            // Fall through
+        } catch (Throwable e) {
+            // Log.wtf would be best but this is in app process, so the TerribleFailureHandler may
+            // have unknown effects, possibly crashing the app (default behavior on eng builds or
+            // if the WTF_IS_FATAL setting is set).
+            Log.e(TAG, "Unexpected error while getting methods from NetworkCallback class", e);
+            // Fall through
+        }
+        return NetworkCallback.DECLARED_METHODS_ALL;
+    }
+
+    private static int getDeclaredMethodsFlag(@NonNull Class<? extends NetworkCallback> clazz) {
+        final Integer cachedFlags = sMethodFlagsCache.get(clazz);
+        // As this is not synchronized, it is possible that this method will calculate the
+        // flags for a given class multiple times, but that is fine. LruCache itself is thread-safe.
+        if (cachedFlags != null) {
+            return cachedFlags;
+        }
+
+        int flag = 0;
+        // This uses getMethods instead of getDeclaredMethods, to make sure that if A overrides B
+        // that overrides NetworkCallback, A.getMethods also returns methods declared by B.
+        for (Method classMethod : clazz.getMethods()) {
+            final Class<?> declaringClass = classMethod.getDeclaringClass();
+            if (declaringClass == NetworkCallback.class) {
+                // The callback is as defined by NetworkCallback and not overridden
+                continue;
+            }
+            if (declaringClass == Object.class) {
+                // Optimization: no need to try to match callbacks for methods declared by Object
+                continue;
+            }
+            flag |= getCallbackIdsCallingThisMethod(classMethod);
+        }
+
+        if (flag == 0) {
+            // dexmaker-mockito-inline (InlineDexmakerMockMaker), for example for mockito-extended,
+            // modifies bytecode of classes in-place to add hooks instead of creating subclasses,
+            // which would not be detected. When no method is found, fall back to enabling callbacks
+            // for all methods.
+            // This will not catch the case where both NetworkCallback bytecode is modified and a
+            // subclass of NetworkCallback that has some overridden methods are used. But this kind
+            // of bytecode injection is only possible in debuggable processes, with a JVMTI debug
+            // agent attached, so it should not cause real issues.
+            // There may be legitimate cases where an empty callback is filed with no method
+            // overridden, for example requestNetwork(requestForCell, new NetworkCallback()) which
+            // would ensure that one cell network stays up. But there is no way to differentiate
+            // such NetworkCallbacks from a mock that called the constructor, so this code will
+            // register the callback with DECLARED_METHODS_ALL and turn off the optimization in that
+            // case. Apps are not expected to do this often anyway since the usefulness is very
+            // limited.
+            flag = NetworkCallback.DECLARED_METHODS_ALL;
+        }
+        sMethodFlagsCache.put(clazz, flag);
+        return flag;
+    }
+
+    /**
+     * Find out which of the base methods in NetworkCallback will call this method.
+     *
+     * For example, in the case of onLinkPropertiesChanged, this will be
+     * (1 << CALLBACK_IP_CHANGED) | (1 << CALLBACK_AVAILABLE).
+     */
+    private static int getCallbackIdsCallingThisMethod(@NonNull Method method) {
+        for (NetworkCallbackMethod baseMethod : NetworkCallbackMethodsHolder.NETWORK_CB_METHODS) {
+            if (!baseMethod.mName.equals(method.getName())) {
+                continue;
+            }
+            Class<?>[] methodParams = method.getParameterTypes();
+
+            // As per JLS 8.4.8.1., a method m1 must have a subsignature of method m2 to override
+            // it. And as per JLS 8.4.2, this means the erasure of the signature of m2 must be the
+            // same as the signature of m1. Since type erasure is done at compile time, with
+            // reflection the erased types are already observed, so the (erased) parameter types
+            // must be equal.
+            // So for example a method that is identical to a NetworkCallback method, except with
+            // one parameter being a subclass of the parameter in the original method, will never
+            // be called since it is not an override (the erasure of the arguments are not the same)
+            // Therefore, the method is an override only if methodParams is exactly equal to
+            // the base method's parameter types.
+            if (Arrays.equals(baseMethod.mParameterTypes, methodParams)) {
+                return baseMethod.mCallbacksCallingThisMethod;
+            }
+        }
+        return 0;
+    }
+
     private boolean isFeatureEnabled(@ConnectivityManagerFeature long connectivityManagerFeature) {
         synchronized (mEnabledConnectivityManagerFeaturesLock) {
             if (mEnabledConnectivityManagerFeatures == null) {
diff --git a/framework/src/android/net/IConnectivityManager.aidl b/framework/src/android/net/IConnectivityManager.aidl
index f9de8ed..988cc92 100644
--- a/framework/src/android/net/IConnectivityManager.aidl
+++ b/framework/src/android/net/IConnectivityManager.aidl
@@ -153,7 +153,8 @@
 
     NetworkRequest requestNetwork(int uid, in NetworkCapabilities networkCapabilities, int reqType,
             in Messenger messenger, int timeoutSec, in IBinder binder, int legacy,
-            int callbackFlags, String callingPackageName, String callingAttributionTag);
+            int callbackFlags, String callingPackageName, String callingAttributionTag,
+            int declaredMethodsFlag);
 
     NetworkRequest pendingRequestForNetwork(in NetworkCapabilities networkCapabilities,
             in PendingIntent operation, String callingPackageName, String callingAttributionTag);
@@ -162,7 +163,7 @@
 
     NetworkRequest listenForNetwork(in NetworkCapabilities networkCapabilities,
             in Messenger messenger, in IBinder binder, int callbackFlags, String callingPackageName,
-            String callingAttributionTag);
+            String callingAttributionTag, int declaredMethodsFlag);
 
     void pendingListenForNetwork(in NetworkCapabilities networkCapabilities,
             in PendingIntent operation, String callingPackageName,
diff --git a/service/src/com/android/server/ConnectivityService.java b/service/src/com/android/server/ConnectivityService.java
index d3b7f79..efd7b30 100755
--- a/service/src/com/android/server/ConnectivityService.java
+++ b/service/src/com/android/server/ConnectivityService.java
@@ -46,6 +46,7 @@
 import static android.net.ConnectivityManager.FIREWALL_RULE_ALLOW;
 import static android.net.ConnectivityManager.FIREWALL_RULE_DEFAULT;
 import static android.net.ConnectivityManager.FIREWALL_RULE_DENY;
+import static android.net.ConnectivityManager.NetworkCallback.DECLARED_METHODS_NONE;
 import static android.net.ConnectivityManager.TYPE_BLUETOOTH;
 import static android.net.ConnectivityManager.TYPE_ETHERNET;
 import static android.net.ConnectivityManager.TYPE_MOBILE;
@@ -65,6 +66,7 @@
 import static android.net.ConnectivityManager.TYPE_WIFI_P2P;
 import static android.net.ConnectivityManager.getNetworkTypeName;
 import static android.net.ConnectivityManager.isNetworkTypeValid;
+import static android.net.ConnectivityManager.NetworkCallback.DECLARED_METHODS_ALL;
 import static android.net.ConnectivitySettingsManager.PRIVATE_DNS_MODE_OPPORTUNISTIC;
 import static android.net.INetd.PERMISSION_INTERNET;
 import static android.net.INetworkMonitor.NETWORK_VALIDATION_PROBE_PRIVDNS;
@@ -126,8 +128,6 @@
 import static com.android.server.connectivity.ConnectivityFlags.REQUEST_RESTRICTED_WIFI;
 import static com.android.server.connectivity.ConnectivityFlags.INGRESS_TO_VPN_ADDRESS_FILTERING;
 
-import static java.util.Map.Entry;
-
 import android.Manifest;
 import android.annotation.CheckResult;
 import android.annotation.NonNull;
@@ -1779,7 +1779,7 @@
         mDefaultRequest = new NetworkRequestInfo(
                 Process.myUid(), defaultInternetRequest, null,
                 null /* binder */, NetworkCallback.FLAG_INCLUDE_LOCATION_INFO,
-                null /* attributionTags */);
+                null /* attributionTags */, DECLARED_METHODS_NONE);
         mNetworkRequests.put(defaultInternetRequest, mDefaultRequest);
         mDefaultNetworkRequests.add(mDefaultRequest);
         mNetworkRequestInfoLogs.log("REGISTER " + mDefaultRequest);
@@ -2166,7 +2166,7 @@
             handleRegisterNetworkRequest(new NetworkRequestInfo(
                     Process.myUid(), networkRequest, null /* messenger */, null /* binder */,
                     NetworkCallback.FLAG_INCLUDE_LOCATION_INFO,
-                    null /* attributionTags */));
+                    null /* attributionTags */, DECLARED_METHODS_NONE));
         } else {
             handleReleaseNetworkRequest(networkRequest, Process.SYSTEM_UID,
                     /* callOnUnavailable */ false);
@@ -7531,6 +7531,8 @@
         // Preference order of this request.
         final int mPreferenceOrder;
 
+        final int mDeclaredMethodsFlags;
+
         // In order to preserve the mapping of NetworkRequest-to-callback when apps register
         // callbacks using a returned NetworkRequest, the original NetworkRequest needs to be
         // maintained for keying off of. This is only a concern when the original nri
@@ -7584,21 +7586,22 @@
             mCallbackFlags = NetworkCallback.FLAG_NONE;
             mCallingAttributionTag = callingAttributionTag;
             mPreferenceOrder = preferenceOrder;
+            mDeclaredMethodsFlags = DECLARED_METHODS_NONE;
         }
 
         NetworkRequestInfo(int asUid, @NonNull final NetworkRequest r, @Nullable final Messenger m,
                 @Nullable final IBinder binder,
                 @NetworkCallback.Flag int callbackFlags,
-                @Nullable String callingAttributionTag) {
+                @Nullable String callingAttributionTag, int declaredMethodsFlags) {
             this(asUid, Collections.singletonList(r), r, m, binder, callbackFlags,
-                    callingAttributionTag);
+                    callingAttributionTag, declaredMethodsFlags);
         }
 
         NetworkRequestInfo(int asUid, @NonNull final List<NetworkRequest> r,
                 @NonNull final NetworkRequest requestForCallback, @Nullable final Messenger m,
                 @Nullable final IBinder binder,
                 @NetworkCallback.Flag int callbackFlags,
-                @Nullable String callingAttributionTag) {
+                @Nullable String callingAttributionTag, int declaredMethodsFlags) {
             super();
             ensureAllNetworkRequestsHaveType(r);
             mRequests = initializeRequests(r);
@@ -7613,6 +7616,7 @@
             mCallbackFlags = callbackFlags;
             mCallingAttributionTag = callingAttributionTag;
             mPreferenceOrder = PREFERENCE_ORDER_INVALID;
+            mDeclaredMethodsFlags = declaredMethodsFlags;
             linkDeathRecipient();
         }
 
@@ -7653,6 +7657,7 @@
             mCallingAttributionTag = nri.mCallingAttributionTag;
             mUidTrackedForBlockedStatus = nri.mUidTrackedForBlockedStatus;
             mPreferenceOrder = PREFERENCE_ORDER_INVALID;
+            mDeclaredMethodsFlags = nri.mDeclaredMethodsFlags;
             linkDeathRecipient();
         }
 
@@ -7740,7 +7745,8 @@
                     + (mPendingIntent == null ? "" : " to trigger " + mPendingIntent)
                     + " callback flags: " + mCallbackFlags
                     + " order: " + mPreferenceOrder
-                    + " isUidTracked: " + mUidTrackedForBlockedStatus;
+                    + " isUidTracked: " + mUidTrackedForBlockedStatus
+                    + " declaredMethods: 0x" + Integer.toHexString(mDeclaredMethodsFlags);
         }
     }
 
@@ -7878,7 +7884,21 @@
     public NetworkRequest requestNetwork(int asUid, NetworkCapabilities networkCapabilities,
             int reqTypeInt, Messenger messenger, int timeoutMs, final IBinder binder,
             int legacyType, int callbackFlags, @NonNull String callingPackageName,
-            @Nullable String callingAttributionTag) {
+            @Nullable String callingAttributionTag, int declaredMethodsFlag) {
+        if (declaredMethodsFlag == 0) {
+            // This could happen if raw binder calls are used to call the previous overload of
+            // requestNetwork, as missing int arguments in a binder call end up as 0
+            // (Parcel.readInt returns 0 at the end of a parcel). Such raw calls this would be
+            // really unexpected bad behavior from the caller though.
+            // TODO: remove after verifying this does not happen. This could allow enabling the
+            // optimization for callbacks that do not override any method (right now they use
+            // DECLARED_METHODS_ALL), if it is OK to break NetworkCallbacks created using
+            // dexmaker-mockito-inline and either spy() or MockSettings.useConstructor (see
+            // comment in ConnectivityManager which sets the flag to DECLARED_METHODS_ALL).
+            Log.wtf(TAG, "requestNetwork called without declaredMethodsFlag from "
+                    + callingPackageName);
+            declaredMethodsFlag = DECLARED_METHODS_ALL;
+        }
         if (legacyType != TYPE_NONE && !hasNetworkStackPermission()) {
             if (isTargetSdkAtleast(Build.VERSION_CODES.M, mDeps.getCallingUid(),
                     callingPackageName)) {
@@ -7970,7 +7990,7 @@
                 nextNetworkRequestId(), reqType);
         final NetworkRequestInfo nri = getNriToRegister(
                 asUid, networkRequest, messenger, binder, callbackFlags,
-                callingAttributionTag);
+                callingAttributionTag, declaredMethodsFlag);
         if (DBG) log("requestNetwork for " + nri);
         trackUidAndRegisterNetworkRequest(EVENT_REGISTER_NETWORK_REQUEST, nri);
         if (timeoutMs > 0) {
@@ -7996,7 +8016,7 @@
     private NetworkRequestInfo getNriToRegister(final int asUid, @NonNull final NetworkRequest nr,
             @Nullable final Messenger msgr, @Nullable final IBinder binder,
             @NetworkCallback.Flag int callbackFlags,
-            @Nullable String callingAttributionTag) {
+            @Nullable String callingAttributionTag, int declaredMethodsFlags) {
         final List<NetworkRequest> requests;
         if (NetworkRequest.Type.TRACK_DEFAULT == nr.type) {
             requests = copyDefaultNetworkRequestsForUid(
@@ -8005,7 +8025,8 @@
             requests = Collections.singletonList(nr);
         }
         return new NetworkRequestInfo(
-                asUid, requests, nr, msgr, binder, callbackFlags, callingAttributionTag);
+                asUid, requests, nr, msgr, binder, callbackFlags, callingAttributionTag,
+                declaredMethodsFlags);
     }
 
     private boolean shouldCheckCapabilitiesDeclaration(
@@ -8295,7 +8316,13 @@
     public NetworkRequest listenForNetwork(NetworkCapabilities networkCapabilities,
             Messenger messenger, IBinder binder,
             @NetworkCallback.Flag int callbackFlags,
-            @NonNull String callingPackageName, @NonNull String callingAttributionTag) {
+            @NonNull String callingPackageName, @NonNull String callingAttributionTag,
+            int declaredMethodsFlag) {
+        if (declaredMethodsFlag == 0) {
+            Log.wtf(TAG, "listenForNetwork called without declaredMethodsFlag from "
+                    + callingPackageName);
+            declaredMethodsFlag = DECLARED_METHODS_ALL;
+        }
         final int callingUid = mDeps.getCallingUid();
         if (!hasWifiNetworkListenPermission(networkCapabilities)) {
             enforceAccessPermission();
@@ -8317,7 +8344,7 @@
                 NetworkRequest.Type.LISTEN);
         NetworkRequestInfo nri =
                 new NetworkRequestInfo(callingUid, networkRequest, messenger, binder, callbackFlags,
-                        callingAttributionTag);
+                        callingAttributionTag, declaredMethodsFlag);
         if (VDBG) log("listenForNetwork for " + nri);
 
         trackUidAndRegisterNetworkRequest(EVENT_REGISTER_NETWORK_LISTENER, nri);
@@ -9759,7 +9786,8 @@
             configBuilder.setUpstreamSelector(nr);
             final NetworkRequestInfo nri = new NetworkRequestInfo(
                     nai.creatorUid, nr, null /* messenger */, null /* binder */,
-                    0 /* callbackFlags */, null /* attributionTag */);
+                    0 /* callbackFlags */, null /* attributionTag */,
+                    DECLARED_METHODS_NONE);
             if (null != oldSatisfier) {
                 // Set the old satisfier in the new NRI so that the rematch will see any changes
                 nri.setSatisfier(oldSatisfier, nr);
@@ -10130,6 +10158,11 @@
             // are Type.LISTEN, but should not have NetworkCallbacks invoked.
             return;
         }
+        if (mUseDeclaredMethodsForCallbacksEnabled
+                && (nri.mDeclaredMethodsFlags & (1 << notificationType)) == 0) {
+            // No need to send the notification as the recipient method is not overridden
+            return;
+        }
         final Bundle bundle = new Bundle();
         // TODO b/177608132: make sure callbacks are indexed by NRIs and not NetworkRequest objects.
         // TODO: check if defensive copies of data is needed.
diff --git a/tests/unit/java/android/net/ConnectivityManagerTest.java b/tests/unit/java/android/net/ConnectivityManagerTest.java
index b71a46f..9a77c89 100644
--- a/tests/unit/java/android/net/ConnectivityManagerTest.java
+++ b/tests/unit/java/android/net/ConnectivityManagerTest.java
@@ -52,6 +52,7 @@
 import static org.mockito.Mockito.any;
 import static org.mockito.Mockito.anyInt;
 import static org.mockito.Mockito.doReturn;
+import static org.mockito.Mockito.inOrder;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.never;
 import static org.mockito.Mockito.reset;
@@ -72,6 +73,7 @@
 import android.os.Messenger;
 import android.os.Process;
 
+import androidx.annotation.NonNull;
 import androidx.test.filters.SmallTest;
 
 import com.android.internal.util.test.BroadcastInterceptingContext;
@@ -83,6 +85,7 @@
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.mockito.ArgumentCaptor;
+import org.mockito.InOrder;
 import org.mockito.Mock;
 import org.mockito.MockitoAnnotations;
 
@@ -240,7 +243,7 @@
 
         // register callback
         when(mService.requestNetwork(anyInt(), any(), anyInt(), captor.capture(), anyInt(), any(),
-                anyInt(), anyInt(), any(), nullable(String.class))).thenReturn(request);
+                anyInt(), anyInt(), any(), nullable(String.class), anyInt())).thenReturn(request);
         manager.requestNetwork(request, callback, handler);
 
         // callback triggers
@@ -269,7 +272,7 @@
 
         // register callback
         when(mService.requestNetwork(anyInt(), any(), anyInt(), captor.capture(), anyInt(), any(),
-                anyInt(), anyInt(), any(), nullable(String.class))).thenReturn(req1);
+                anyInt(), anyInt(), any(), nullable(String.class), anyInt())).thenReturn(req1);
         manager.requestNetwork(req1, callback, handler);
 
         // callback triggers
@@ -287,7 +290,7 @@
 
         // callback can be registered again
         when(mService.requestNetwork(anyInt(), any(), anyInt(), captor.capture(), anyInt(), any(),
-                anyInt(), anyInt(), any(), nullable(String.class))).thenReturn(req2);
+                anyInt(), anyInt(), any(), nullable(String.class), anyInt())).thenReturn(req2);
         manager.requestNetwork(req2, callback, handler);
 
         // callback triggers
@@ -311,7 +314,7 @@
 
         when(mCtx.getApplicationInfo()).thenReturn(info);
         when(mService.requestNetwork(anyInt(), any(), anyInt(), any(), anyInt(), any(), anyInt(),
-                anyInt(), any(), nullable(String.class))).thenReturn(request);
+                anyInt(), any(), nullable(String.class), anyInt())).thenReturn(request);
 
         Handler handler = new Handler(Looper.getMainLooper());
         manager.requestNetwork(request, callback, handler);
@@ -403,15 +406,15 @@
         manager.requestNetwork(request, callback);
         verify(mService).requestNetwork(eq(Process.INVALID_UID), eq(request.networkCapabilities),
                 eq(REQUEST.ordinal()), any(), anyInt(), any(), eq(TYPE_NONE), anyInt(),
-                eq(testPkgName), eq(testAttributionTag));
+                eq(testPkgName), eq(testAttributionTag), anyInt());
         reset(mService);
 
         // Verify that register network callback does not calls requestNetwork at all.
         manager.registerNetworkCallback(request, callback);
         verify(mService, never()).requestNetwork(anyInt(), any(), anyInt(), any(), anyInt(), any(),
-                anyInt(), anyInt(), any(), any());
+                anyInt(), anyInt(), any(), any(), anyInt());
         verify(mService).listenForNetwork(eq(request.networkCapabilities), any(), any(), anyInt(),
-                eq(testPkgName), eq(testAttributionTag));
+                eq(testPkgName), eq(testAttributionTag), anyInt());
         reset(mService);
 
         Handler handler = new Handler(ConnectivityThread.getInstanceLooper());
@@ -419,24 +422,24 @@
         manager.registerDefaultNetworkCallback(callback);
         verify(mService).requestNetwork(eq(Process.INVALID_UID), eq(null),
                 eq(TRACK_DEFAULT.ordinal()), any(), anyInt(), any(), eq(TYPE_NONE), anyInt(),
-                eq(testPkgName), eq(testAttributionTag));
+                eq(testPkgName), eq(testAttributionTag), anyInt());
         reset(mService);
 
         manager.registerDefaultNetworkCallbackForUid(42, callback, handler);
         verify(mService).requestNetwork(eq(42), eq(null),
                 eq(TRACK_DEFAULT.ordinal()), any(), anyInt(), any(), eq(TYPE_NONE), anyInt(),
-                eq(testPkgName), eq(testAttributionTag));
+                eq(testPkgName), eq(testAttributionTag), anyInt());
 
         manager.requestBackgroundNetwork(request, callback, handler);
         verify(mService).requestNetwork(eq(Process.INVALID_UID), eq(request.networkCapabilities),
                 eq(BACKGROUND_REQUEST.ordinal()), any(), anyInt(), any(), eq(TYPE_NONE), anyInt(),
-                eq(testPkgName), eq(testAttributionTag));
+                eq(testPkgName), eq(testAttributionTag), anyInt());
         reset(mService);
 
         manager.registerSystemDefaultNetworkCallback(callback, handler);
         verify(mService).requestNetwork(eq(Process.INVALID_UID), eq(null),
                 eq(TRACK_SYSTEM_DEFAULT.ordinal()), any(), anyInt(), any(), eq(TYPE_NONE), anyInt(),
-                eq(testPkgName), eq(testAttributionTag));
+                eq(testPkgName), eq(testAttributionTag), anyInt());
         reset(mService);
     }
 
@@ -516,16 +519,154 @@
                     + " attempts", ref.get());
     }
 
-    private <T> void mockService(Class<T> clazz, String name, T service) {
-        doReturn(service).when(mCtx).getSystemService(name);
-        doReturn(name).when(mCtx).getSystemServiceName(clazz);
+    @Test
+    public void testDeclaredMethodsFlag_requestWithMixedMethods_RegistrationFlagsMatch()
+            throws Exception {
+        doReturn(ConnectivityManager.FEATURE_USE_DECLARED_METHODS_FOR_CALLBACKS)
+                .when(mService).getEnabledConnectivityManagerFeatures();
+        final ConnectivityManager manager = new ConnectivityManager(mCtx, mService);
 
-        // If the test suite uses the inline mock maker library, such as for coverage tests,
-        // then the final version of getSystemService must also be mocked, as the real
-        // method will not be called by the test and null object is returned since no mock.
-        // Otherwise, mocking a final method will fail the test.
-        if (mCtx.getSystemService(clazz) == null) {
-            doReturn(service).when(mCtx).getSystemService(clazz);
-        }
+        final NetworkRequest request = new NetworkRequest.Builder().build();
+        final NetworkCallback callback1 = new ConnectivityManager.NetworkCallback() {
+            @Override
+            public void onPreCheck(@NonNull Network network) {}
+            @Override
+            public void onAvailable(@NonNull Network network) {}
+            @Override
+            public void onLost(@NonNull Network network) {}
+            @Override
+            public void onCapabilitiesChanged(@NonNull Network network,
+                    @NonNull NetworkCapabilities networkCapabilities) {}
+            @Override
+            public void onLocalNetworkInfoChanged(@NonNull Network network,
+                    @NonNull LocalNetworkInfo localNetworkInfo) {}
+            @Override
+            public void onNetworkResumed(@NonNull Network network) {}
+            @Override
+            public void onBlockedStatusChanged(@NonNull Network network, int blocked) {}
+        };
+        manager.requestNetwork(request, callback1);
+
+        final InOrder inOrder = inOrder(mService);
+        inOrder.verify(mService).requestNetwork(
+                anyInt(), any(), anyInt(), any(), anyInt(), any(), anyInt(), anyInt(), any(), any(),
+                eq(1 << ConnectivityManager.CALLBACK_PRECHECK
+                        | 1 << ConnectivityManager.CALLBACK_AVAILABLE
+                        | 1 << ConnectivityManager.CALLBACK_LOST
+                        | 1 << ConnectivityManager.CALLBACK_CAP_CHANGED
+                        | 1 << ConnectivityManager.CALLBACK_LOCAL_NETWORK_INFO_CHANGED
+                        | 1 << ConnectivityManager.CALLBACK_RESUMED
+                        | 1 << ConnectivityManager.CALLBACK_BLK_CHANGED));
+    }
+
+    @Test
+    public void testDeclaredMethodsFlag_listenWithMixedMethods_RegistrationFlagsMatch()
+            throws Exception {
+        final NetworkRequest request = new NetworkRequest.Builder().build();
+        doReturn(ConnectivityManager.FEATURE_USE_DECLARED_METHODS_FOR_CALLBACKS)
+                .when(mService).getEnabledConnectivityManagerFeatures();
+        final ConnectivityManager manager = new ConnectivityManager(mCtx, mService);
+
+        final NetworkCallback callback2 = new ConnectivityManager.NetworkCallback() {
+            @Override
+            public void onLosing(@NonNull Network network, int maxMsToLive) {}
+            @Override
+            public void onUnavailable() {}
+            @Override
+            public void onLinkPropertiesChanged(@NonNull Network network,
+                    @NonNull LinkProperties linkProperties) {}
+            @Override
+            public void onNetworkSuspended(@NonNull Network network) {}
+        };
+        manager.registerNetworkCallback(request, callback2);
+        // Call a second time with the same callback to exercise caching
+        manager.registerNetworkCallback(request, callback2);
+
+        verify(mService, times(2)).listenForNetwork(
+                any(), any(), any(), anyInt(), any(), any(),
+                eq(1 << ConnectivityManager.CALLBACK_LOSING
+                        // AVAILABLE calls IP_CHANGED and SUSPENDED so it gets added
+                        | 1 << ConnectivityManager.CALLBACK_AVAILABLE
+                        | 1 << ConnectivityManager.CALLBACK_UNAVAIL
+                        | 1 << ConnectivityManager.CALLBACK_IP_CHANGED
+                        | 1 << ConnectivityManager.CALLBACK_SUSPENDED));
+    }
+
+    @Test
+    public void testDeclaredMethodsFlag_requestWithHiddenAvailableCallback_RegistrationFlagsMatch()
+            throws Exception {
+        doReturn(ConnectivityManager.FEATURE_USE_DECLARED_METHODS_FOR_CALLBACKS)
+                .when(mService).getEnabledConnectivityManagerFeatures();
+        final ConnectivityManager manager = new ConnectivityManager(mCtx, mService);
+
+        final NetworkCallback hiddenOnAvailableCb = new ConnectivityManager.NetworkCallback() {
+            // This overload is @hide but might still be used by (bad) apps
+            @Override
+            public void onAvailable(@NonNull Network network,
+                    @NonNull NetworkCapabilities networkCapabilities,
+                    @NonNull LinkProperties linkProperties, boolean blocked) {}
+        };
+        manager.registerDefaultNetworkCallback(hiddenOnAvailableCb);
+
+        verify(mService).requestNetwork(
+                anyInt(), any(), anyInt(), any(), anyInt(), any(), anyInt(), anyInt(), any(), any(),
+                eq(1 << ConnectivityManager.CALLBACK_AVAILABLE));
+    }
+
+    public static class NetworkCallbackWithOnLostOnly extends NetworkCallback {
+        @Override
+        public void onLost(@NonNull Network network) {}
+    }
+
+    @Test
+    public void testDeclaredMethodsFlag_requestWithoutAvailableCallback_RegistrationFlagsMatch()
+            throws Exception {
+        doReturn(ConnectivityManager.FEATURE_USE_DECLARED_METHODS_FOR_CALLBACKS)
+                .when(mService).getEnabledConnectivityManagerFeatures();
+        final ConnectivityManager manager = new ConnectivityManager(mCtx, mService);
+        final Handler handler = new Handler(Looper.getMainLooper());
+
+        final NetworkCallback noOnAvailableCb = new NetworkCallbackWithOnLostOnly();
+        manager.registerSystemDefaultNetworkCallback(noOnAvailableCb, handler);
+
+        verify(mService).requestNetwork(
+                anyInt(), any(), anyInt(), any(), anyInt(), any(), anyInt(), anyInt(), any(), any(),
+                eq(1 << ConnectivityManager.CALLBACK_LOST));
+    }
+
+    @Test
+    public void testDeclaredMethodsFlag_listenWithMock_OptimizationDisabled()
+            throws Exception {
+        doReturn(ConnectivityManager.FEATURE_USE_DECLARED_METHODS_FOR_CALLBACKS)
+                .when(mService).getEnabledConnectivityManagerFeatures();
+        final ConnectivityManager manager = new ConnectivityManager(mCtx, mService);
+        final Handler handler = new Handler(Looper.getMainLooper());
+
+        final NetworkRequest request = new NetworkRequest.Builder().build();
+        manager.registerNetworkCallback(request, mock(NetworkCallbackWithOnLostOnly.class),
+                handler);
+
+        verify(mService).listenForNetwork(
+                any(), any(), any(), anyInt(), any(), any(),
+                // Mock that does not call the constructor -> do not use the optimization
+                eq(~0));
+    }
+
+    @Test
+    public void testDeclaredMethodsFlag_requestWitNoCallback_OptimizationDisabled()
+            throws Exception {
+        doReturn(ConnectivityManager.FEATURE_USE_DECLARED_METHODS_FOR_CALLBACKS)
+                .when(mService).getEnabledConnectivityManagerFeatures();
+        final ConnectivityManager manager = new ConnectivityManager(mCtx, mService);
+        final Handler handler = new Handler(Looper.getMainLooper());
+
+        final NetworkRequest request = new NetworkRequest.Builder().build();
+        final NetworkCallback noCallbackAtAll = new ConnectivityManager.NetworkCallback() {};
+        manager.requestBackgroundNetwork(request, noCallbackAtAll, handler);
+
+        verify(mService).requestNetwork(
+                anyInt(), any(), anyInt(), any(), anyInt(), any(), anyInt(), anyInt(), any(), any(),
+                // No callbacks overridden -> do not use the optimization
+                eq(~0));
     }
 }
diff --git a/tests/unit/java/android/net/NetworkCallbackFlagsTest.kt b/tests/unit/java/android/net/NetworkCallbackFlagsTest.kt
new file mode 100644
index 0000000..af06a64
--- /dev/null
+++ b/tests/unit/java/android/net/NetworkCallbackFlagsTest.kt
@@ -0,0 +1,212 @@
+/*
+ * Copyright (C) 2024 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package android.net
+
+import android.net.ConnectivityManager.NetworkCallback
+import android.net.ConnectivityManager.NetworkCallbackMethodsHolder
+import android.os.Build
+import com.android.testutils.DevSdkIgnoreRule
+import com.android.testutils.DevSdkIgnoreRunner
+import java.lang.reflect.Method
+import java.lang.reflect.Modifier
+import kotlin.test.assertEquals
+import kotlin.test.assertNotNull
+import kotlin.test.fail
+import org.junit.Test
+import org.junit.runner.RunWith
+import org.mockito.Mockito.any
+import org.mockito.Mockito.doCallRealMethod
+import org.mockito.Mockito.mock
+import org.mockito.Mockito.mockingDetails
+
+@DevSdkIgnoreRule.IgnoreUpTo(Build.VERSION_CODES.R)
+@RunWith(DevSdkIgnoreRunner::class)
+class NetworkCallbackFlagsTest {
+
+    // To avoid developers forgetting to update NETWORK_CB_METHODS when modifying NetworkCallbacks,
+    // or using wrong values, calculate it from annotations here and verify that it matches.
+    // This avoids the runtime cost of reflection, but still ensures that the list is correct.
+    @Test
+    fun testNetworkCallbackMethods_calculateFromAnnotations_matchesHardcodedList() {
+        val calculatedMethods = getNetworkCallbackMethodsFromAnnotations()
+        assertEquals(
+            calculatedMethods.toSet(),
+            NetworkCallbackMethodsHolder.NETWORK_CB_METHODS.map {
+                NetworkCallbackMethodWithEquals(
+                    it.mName,
+                    it.mParameterTypes.toList(),
+                    callbacksCallingThisMethod = it.mCallbacksCallingThisMethod
+                )
+            }.toSet()
+        )
+    }
+
+    data class NetworkCallbackMethodWithEquals(
+        val name: String,
+        val parameterTypes: List<Class<*>>,
+        val callbacksCallingThisMethod: Int
+    )
+
+    data class NetworkCallbackMethodBuilder(
+        val name: String,
+        val parameterTypes: List<Class<*>>,
+        val isFinal: Boolean,
+        val methodId: Int,
+        val mayCall: Set<Int>?,
+        var callbacksCallingThisMethod: Int
+    ) {
+        fun build() = NetworkCallbackMethodWithEquals(
+            name,
+            parameterTypes,
+            callbacksCallingThisMethod
+        )
+    }
+
+    /**
+     * Build [NetworkCallbackMethodsHolder.NETWORK_CB_METHODS] from [NetworkCallback] annotations.
+     */
+    private fun getNetworkCallbackMethodsFromAnnotations(): List<NetworkCallbackMethodWithEquals> {
+        val parsedMethods = mutableListOf<NetworkCallbackMethodBuilder>()
+        val methods = NetworkCallback::class.java.declaredMethods
+        methods.forEach { method ->
+            val cb = method.getAnnotation(
+                NetworkCallback.FilteredCallback::class.java
+            ) ?: return@forEach
+            val callbacksCallingThisMethod = if (cb.calledByCallbackId == 0) {
+                0
+            } else {
+                1 shl cb.calledByCallbackId
+            }
+            parsedMethods.add(
+                NetworkCallbackMethodBuilder(
+                    method.name,
+                    method.parameterTypes.toList(),
+                    Modifier.isFinal(method.modifiers),
+                    cb.methodId,
+                    cb.mayCall.toSet(),
+                    callbacksCallingThisMethod
+                )
+            )
+        }
+
+        // Propagate callbacksCallingThisMethod for transitive calls
+        do {
+            var hadChange = false
+            parsedMethods.forEach { caller ->
+                parsedMethods.forEach { callee ->
+                    if (caller.mayCall?.contains(callee.methodId) == true) {
+                        // Callbacks that call the caller also cause calls to the callee. So
+                        // callbacksCallingThisMethod for the callee should include
+                        // callbacksCallingThisMethod from the caller.
+                        val newValue =
+                            caller.callbacksCallingThisMethod or callee.callbacksCallingThisMethod
+                        hadChange = hadChange || callee.callbacksCallingThisMethod != newValue
+                        callee.callbacksCallingThisMethod = newValue
+                    }
+                }
+            }
+        } while (hadChange)
+
+        // Final methods may affect the flags for transitive calls, but cannot be overridden, so do
+        // not need to be in the list (no overridden method in NetworkCallback will match them).
+        return parsedMethods.filter { !it.isFinal }.map { it.build() }
+    }
+
+    @Test
+    fun testMethodsAreAnnotated() {
+        val annotations = NetworkCallback::class.java.declaredMethods.mapNotNull { method ->
+            if (!Modifier.isPublic(method.modifiers) && !Modifier.isProtected(method.modifiers)) {
+                return@mapNotNull null
+            }
+            val annotation = method.getAnnotation(NetworkCallback.FilteredCallback::class.java)
+            assertNotNull(annotation, "$method is missing the @FilteredCallback annotation")
+            return@mapNotNull annotation
+        }
+
+        annotations.groupingBy { it.methodId }.eachCount().forEach { (methodId, cnt) ->
+            assertEquals(1, cnt, "Method ID $methodId is used more than once in @FilteredCallback")
+        }
+    }
+
+    @Test
+    fun testObviousCalleesAreInAnnotation() {
+        NetworkCallback::class.java.declaredMethods.forEach { method ->
+            val annotation = method.getAnnotation(NetworkCallback.FilteredCallback::class.java)
+                ?: return@forEach
+            val missingFlags = getObviousCallees(method).toMutableSet().apply {
+                removeAll(annotation.mayCall.toSet())
+            }
+            val msg = "@FilteredCallback on $method is missing flags " +
+                    "$missingFlags in mayCall. There may be other " +
+                    "calls that are not detected if they are done conditionally."
+            assertEquals(emptySet(), missingFlags, msg)
+        }
+    }
+
+    /**
+     * Invoke the specified NetworkCallback method with mock arguments, return a set of transitively
+     * called methods.
+     *
+     * This provides an idea of which methods are transitively called by the specified method. It's
+     * not perfect as some callees could be called or not depending on the exact values of the mock
+     * arguments that are passed in (for example, onAvailable calls onNetworkSuspended only if the
+     * capabilities lack the NOT_SUSPENDED capability), but it should catch obvious forgotten calls.
+     */
+    private fun getObviousCallees(method: Method): Set<Int> {
+        // Create a mock NetworkCallback that mocks all methods except the one specified by the
+        // caller.
+        val mockCallback = mock(NetworkCallback::class.java)
+
+        if (!Modifier.isFinal(method.modifiers) ||
+            // The mock class will be NetworkCallback (not a subclass) if using mockito-inline,
+            // which mocks final methods too
+            mockCallback.javaClass == NetworkCallback::class.java) {
+            doCallRealMethod().`when`(mockCallback).let { mockObj ->
+                val anyArgs = method.parameterTypes.map { any(it) }
+                method.invoke(mockObj, *anyArgs.toTypedArray())
+            }
+        }
+
+        // Invoke the target method with mock parameters
+        val mockParameters = method.parameterTypes.map { getMockFor(method, it) }
+        method.invoke(mockCallback, *mockParameters.toTypedArray())
+
+        // Aggregate callees
+        val mockingDetails = mockingDetails(mockCallback)
+        return mockingDetails.invocations.mapNotNull { inv ->
+            if (inv.method == method) {
+                null
+            } else {
+                inv.method.getAnnotation(NetworkCallback.FilteredCallback::class.java)?.methodId
+            }
+        }.toSet()
+    }
+
+    private fun getMockFor(method: Method, c: Class<*>): Any {
+        if (!c.isPrimitive && !Modifier.isFinal(c.modifiers)) {
+            return mock(c)
+        }
+        return when (c) {
+            NetworkCapabilities::class.java -> NetworkCapabilities()
+            LinkProperties::class.java -> LinkProperties()
+            LocalNetworkInfo::class.java -> LocalNetworkInfo(null)
+            Boolean::class.java -> false
+            Int::class.java -> 0
+            else -> fail("No mock set for parameter type $c used in $method")
+        }
+    }
+}
diff --git a/tests/unit/java/com/android/server/ConnectivityServiceTest.java b/tests/unit/java/com/android/server/ConnectivityServiceTest.java
index 8526a9a..be7f2a3 100755
--- a/tests/unit/java/com/android/server/ConnectivityServiceTest.java
+++ b/tests/unit/java/com/android/server/ConnectivityServiceTest.java
@@ -2871,7 +2871,7 @@
         };
         final NetworkRequest request = mService.listenForNetwork(caps, messenger, binder,
                 NetworkCallback.FLAG_NONE, mContext.getOpPackageName(),
-                mContext.getAttributionTag());
+                mContext.getAttributionTag(), ~0 /* declaredMethodsFlag */);
         mService.releaseNetworkRequest(request);
         deathRecipient.get().binderDied();
         // Wait for the release message to be processed.
@@ -5407,7 +5407,7 @@
             mService.requestNetwork(Process.INVALID_UID, networkCapabilities,
                     NetworkRequest.Type.REQUEST.ordinal(), null, 0, null,
                     ConnectivityManager.TYPE_WIFI, NetworkCallback.FLAG_NONE,
-                    mContext.getPackageName(), getAttributionTag());
+                    mContext.getPackageName(), getAttributionTag(), ~0 /* declaredMethodsFlag */);
         });
 
         final NetworkRequest.Builder builder =
@@ -13655,7 +13655,8 @@
                     IllegalArgumentException.class,
                     () -> mService.requestNetwork(Process.INVALID_UID, nc, reqTypeInt, null, 0,
                             null, ConnectivityManager.TYPE_NONE, NetworkCallback.FLAG_NONE,
-                            mContext.getPackageName(), getAttributionTag())
+                            mContext.getPackageName(), getAttributionTag(),
+                            ~0 /* declaredMethodsFlag */)
             );
         }
     }
diff --git a/tests/unit/java/com/android/server/connectivityservice/CSDeclaredMethodsForCallbacksTest.kt b/tests/unit/java/com/android/server/connectivityservice/CSDeclaredMethodsForCallbacksTest.kt
new file mode 100644
index 0000000..cf990b1
--- /dev/null
+++ b/tests/unit/java/com/android/server/connectivityservice/CSDeclaredMethodsForCallbacksTest.kt
@@ -0,0 +1,141 @@
+/*
+ * Copyright (C) 2024 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.server.connectivityservice
+
+import android.net.ConnectivityManager
+import android.net.ConnectivityManager.CALLBACK_CAP_CHANGED
+import android.net.ConnectivityManager.CALLBACK_IP_CHANGED
+import android.net.ConnectivityManager.CALLBACK_LOST
+import android.net.ConnectivityManager.NetworkCallback.DECLARED_METHODS_ALL
+import android.net.LinkAddress
+import android.net.NetworkCapabilities.NET_CAPABILITY_TEMPORARILY_NOT_METERED
+import android.net.NetworkRequest
+import android.os.Build
+import com.android.net.module.util.BitUtils.packBits
+import com.android.server.CSTest
+import com.android.server.ConnectivityService
+import com.android.server.defaultLp
+import com.android.server.defaultNc
+import com.android.testutils.DevSdkIgnoreRule
+import com.android.testutils.DevSdkIgnoreRunner
+import com.android.testutils.RecorderCallback.CallbackEntry
+import com.android.testutils.TestableNetworkCallback
+import com.android.testutils.tryTest
+import java.util.concurrent.atomic.AtomicInteger
+import org.junit.Before
+import org.junit.Test
+import org.junit.runner.RunWith
+import org.mockito.Mockito.any
+import org.mockito.Mockito.anyInt
+import org.mockito.Mockito.doAnswer
+import org.mockito.Mockito.spy
+
+@DevSdkIgnoreRule.IgnoreUpTo(Build.VERSION_CODES.R)
+@DevSdkIgnoreRunner.MonitorThreadLeak
+@RunWith(DevSdkIgnoreRunner::class)
+class CSDeclaredMethodsForCallbacksTest : CSTest() {
+    private val mockedCallbackFlags = AtomicInteger(DECLARED_METHODS_ALL)
+    private lateinit var wrappedService: ConnectivityService
+
+    private val instrumentedCm by lazy { ConnectivityManager(context, wrappedService) }
+
+    @Before
+    fun setUpWrappedService() {
+        // Mock the callback flags set by ConnectivityManager when calling ConnectivityService, to
+        // simulate methods not being overridden
+        wrappedService = spy(service)
+        doAnswer { inv ->
+            service.requestNetwork(
+                inv.getArgument(0),
+                inv.getArgument(1),
+                inv.getArgument(2),
+                inv.getArgument(3),
+                inv.getArgument(4),
+                inv.getArgument(5),
+                inv.getArgument(6),
+                inv.getArgument(7),
+                inv.getArgument(8),
+                inv.getArgument(9),
+                mockedCallbackFlags.get())
+        }.`when`(wrappedService).requestNetwork(
+            anyInt(),
+            any(),
+            anyInt(),
+            any(),
+            anyInt(),
+            any(),
+            anyInt(),
+            anyInt(),
+            any(),
+            any(),
+            anyInt()
+        )
+        doAnswer { inv ->
+            service.listenForNetwork(
+                inv.getArgument(0),
+                inv.getArgument(1),
+                inv.getArgument(2),
+                inv.getArgument(3),
+                inv.getArgument(4),
+                inv.getArgument(5),
+                mockedCallbackFlags.get()
+            )
+        }.`when`(wrappedService)
+            .listenForNetwork(any(), any(), any(), anyInt(), any(), any(), anyInt())
+    }
+
+    @Test
+    fun testCallbacksAreFiltered() {
+        val requestCb = TestableNetworkCallback()
+        val listenCb = TestableNetworkCallback()
+        mockedCallbackFlags.withFlags(CALLBACK_IP_CHANGED, CALLBACK_LOST) {
+            instrumentedCm.requestNetwork(NetworkRequest.Builder().build(), requestCb)
+        }
+        mockedCallbackFlags.withFlags(CALLBACK_CAP_CHANGED) {
+            instrumentedCm.registerNetworkCallback(NetworkRequest.Builder().build(), listenCb)
+        }
+
+        with(Agent()) {
+            connect()
+            sendLinkProperties(defaultLp().apply {
+                addLinkAddress(LinkAddress("fe80:db8::123/64"))
+            })
+            sendNetworkCapabilities(defaultNc().apply {
+                addCapability(NET_CAPABILITY_TEMPORARILY_NOT_METERED)
+            })
+            disconnect()
+        }
+        waitForIdle()
+
+        // Only callbacks for the corresponding flags are called
+        requestCb.expect<CallbackEntry.LinkPropertiesChanged>()
+        requestCb.expect<CallbackEntry.Lost>()
+        requestCb.assertNoCallback(timeoutMs = 0L)
+
+        listenCb.expect<CallbackEntry.CapabilitiesChanged>()
+        listenCb.assertNoCallback(timeoutMs = 0L)
+    }
+}
+
+private fun AtomicInteger.withFlags(vararg flags: Int, action: () -> Unit) {
+    tryTest {
+        set(packBits(flags).toInt())
+        action()
+    } cleanup {
+        set(DECLARED_METHODS_ALL)
+    }
+}