Merge "Add unit tests related to data accounting for VPNs with one underlying network."
diff --git a/core/java/android/net/ConnectivityManager.java b/core/java/android/net/ConnectivityManager.java
index b7ba970..e5802c2 100644
--- a/core/java/android/net/ConnectivityManager.java
+++ b/core/java/android/net/ConnectivityManager.java
@@ -38,7 +38,6 @@
 import android.os.Build.VERSION_CODES;
 import android.os.Bundle;
 import android.os.Handler;
-import android.os.HandlerThread;
 import android.os.IBinder;
 import android.os.INetworkActivityListener;
 import android.os.INetworkManagementService;
@@ -75,6 +74,9 @@
 import java.util.List;
 import java.util.Map;
 import java.util.concurrent.Executor;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.RejectedExecutionException;
 
 /**
  * Class that answers queries about the state of network connectivity. It also
@@ -1813,23 +1815,26 @@
         public static final int MIN_INTERVAL = 10;
 
         private final Network mNetwork;
-        private final PacketKeepaliveCallback mCallback;
-        private final Looper mLooper;
-        private final Messenger mMessenger;
+        private final ISocketKeepaliveCallback mCallback;
+        private final ExecutorService mExecutor;
 
         private volatile Integer mSlot;
 
-        void stopLooper() {
-            mLooper.quit();
-        }
-
         @UnsupportedAppUsage
         public void stop() {
             try {
-                mService.stopKeepalive(mNetwork, mSlot);
-            } catch (RemoteException e) {
-                Log.e(TAG, "Error stopping packet keepalive: ", e);
-                stopLooper();
+                mExecutor.execute(() -> {
+                    try {
+                        if (mSlot != null) {
+                            mService.stopKeepalive(mNetwork, mSlot);
+                        }
+                    } catch (RemoteException e) {
+                        Log.e(TAG, "Error stopping packet keepalive: ", e);
+                        throw e.rethrowFromSystemServer();
+                    }
+                });
+            } catch (RejectedExecutionException e) {
+                // The internal executor has already stopped due to previous event.
             }
         }
 
@@ -1837,40 +1842,43 @@
             Preconditions.checkNotNull(network, "network cannot be null");
             Preconditions.checkNotNull(callback, "callback cannot be null");
             mNetwork = network;
-            mCallback = callback;
-            HandlerThread thread = new HandlerThread(TAG);
-            thread.start();
-            mLooper = thread.getLooper();
-            mMessenger = new Messenger(new Handler(mLooper) {
+            mExecutor = Executors.newSingleThreadExecutor();
+            mCallback = new ISocketKeepaliveCallback.Stub() {
                 @Override
-                public void handleMessage(Message message) {
-                    switch (message.what) {
-                        case NetworkAgent.EVENT_SOCKET_KEEPALIVE:
-                            int error = message.arg2;
-                            try {
-                                if (error == SUCCESS) {
-                                    if (mSlot == null) {
-                                        mSlot = message.arg1;
-                                        mCallback.onStarted();
-                                    } else {
-                                        mSlot = null;
-                                        stopLooper();
-                                        mCallback.onStopped();
-                                    }
-                                } else {
-                                    stopLooper();
-                                    mCallback.onError(error);
-                                }
-                            } catch (Exception e) {
-                                Log.e(TAG, "Exception in keepalive callback(" + error + ")", e);
-                            }
-                            break;
-                        default:
-                            Log.e(TAG, "Unhandled message " + Integer.toHexString(message.what));
-                            break;
-                    }
+                public void onStarted(int slot) {
+                    Binder.withCleanCallingIdentity(() ->
+                            mExecutor.execute(() -> {
+                                mSlot = slot;
+                                callback.onStarted();
+                            }));
                 }
-            });
+
+                @Override
+                public void onStopped() {
+                    Binder.withCleanCallingIdentity(() ->
+                            mExecutor.execute(() -> {
+                                mSlot = null;
+                                callback.onStopped();
+                            }));
+                    mExecutor.shutdown();
+                }
+
+                @Override
+                public void onError(int error) {
+                    Binder.withCleanCallingIdentity(() ->
+                            mExecutor.execute(() -> {
+                                mSlot = null;
+                                callback.onError(error);
+                            }));
+                    mExecutor.shutdown();
+                }
+
+                @Override
+                public void onDataReceived() {
+                    // PacketKeepalive is only used for Nat-T keepalive and as such does not invoke
+                    // this callback when data is received.
+                }
+            };
         }
     }
 
@@ -1887,12 +1895,11 @@
             InetAddress srcAddr, int srcPort, InetAddress dstAddr) {
         final PacketKeepalive k = new PacketKeepalive(network, callback);
         try {
-            mService.startNattKeepalive(network, intervalSeconds, k.mMessenger, new Binder(),
+            mService.startNattKeepalive(network, intervalSeconds, k.mCallback,
                     srcAddr.getHostAddress(), srcPort, dstAddr.getHostAddress());
         } catch (RemoteException e) {
             Log.e(TAG, "Error starting packet keepalive: ", e);
-            k.stopLooper();
-            return null;
+            throw e.rethrowFromSystemServer();
         }
         return k;
     }
@@ -2805,24 +2812,7 @@
          *         {@link #TETHER_ERROR_PROVISION_FAILED}, or
          *         {@link #TETHER_ERROR_ENTITLEMENT_UNKONWN}.
          */
-        void onEntitlementResult(@EntitlementResultCode int resultCode);
-    }
-
-    /**
-     * @removed
-     * @deprecated This API would be removed when all of caller has been updated.
-     * */
-    @Deprecated
-    public abstract static class TetheringEntitlementValueListener  {
-        /**
-         * Called to notify entitlement result.
-         *
-         * @param resultCode a int value of entitlement result. It may be one of
-         *         {@link #TETHER_ERROR_NO_ERROR},
-         *         {@link #TETHER_ERROR_PROVISION_FAILED}, or
-         *         {@link #TETHER_ERROR_ENTITLEMENT_UNKONWN}.
-         */
-        public void onEntitlementResult(int resultCode) {}
+        void onTetheringEntitlementResult(@EntitlementResultCode int resultCode);
     }
 
     /**
@@ -2855,7 +2845,7 @@
             protected void onReceiveResult(int resultCode, Bundle resultData) {
                 Binder.withCleanCallingIdentity(() ->
                             executor.execute(() -> {
-                                listener.onEntitlementResult(resultCode);
+                                listener.onTetheringEntitlementResult(resultCode);
                             }));
             }
         };
@@ -2871,31 +2861,6 @@
     }
 
     /**
-     * @removed
-     * @deprecated This API would be removed when all of caller has been updated.
-     * */
-    @Deprecated
-    public void getLatestTetheringEntitlementValue(int type, boolean showEntitlementUi,
-            @NonNull final TetheringEntitlementValueListener listener, @Nullable Handler handler) {
-        Preconditions.checkNotNull(listener, "TetheringEntitlementValueListener cannot be null.");
-        ResultReceiver wrappedListener = new ResultReceiver(handler) {
-            @Override
-            protected void onReceiveResult(int resultCode, Bundle resultData) {
-                listener.onEntitlementResult(resultCode);
-            }
-        };
-
-        try {
-            String pkgName = mContext.getOpPackageName();
-            Log.i(TAG, "getLatestTetheringEntitlementValue:" + pkgName);
-            mService.getLatestTetheringEntitlementResult(type, wrappedListener,
-                    showEntitlementUi, pkgName);
-        } catch (RemoteException e) {
-            throw e.rethrowFromSystemServer();
-        }
-    }
-
-    /**
      * Report network connectivity status.  This is currently used only
      * to alter status bar UI.
      * <p>This method requires the caller to hold the permission
@@ -3158,11 +3123,11 @@
         }
     }
 
-    /** {@hide} */
+    /** {@hide} - returns the factory serial number */
     @UnsupportedAppUsage
-    public void registerNetworkFactory(Messenger messenger, String name) {
+    public int registerNetworkFactory(Messenger messenger, String name) {
         try {
-            mService.registerNetworkFactory(messenger, name);
+            return mService.registerNetworkFactory(messenger, name);
         } catch (RemoteException e) {
             throw e.rethrowFromSystemServer();
         }
@@ -3178,6 +3143,10 @@
         }
     }
 
+    // TODO : remove this method. It is a stopgap measure to help sheperding a number
+    // of dependent changes that would conflict throughout the automerger graph. Having this
+    // temporarily helps with the process of going through with all these dependent changes across
+    // the entire tree.
     /**
      * @hide
      * Register a NetworkAgent with ConnectivityService.
@@ -3185,8 +3154,20 @@
      */
     public int registerNetworkAgent(Messenger messenger, NetworkInfo ni, LinkProperties lp,
             NetworkCapabilities nc, int score, NetworkMisc misc) {
+        return registerNetworkAgent(messenger, ni, lp, nc, score, misc,
+                NetworkFactory.SerialNumber.NONE);
+    }
+
+    /**
+     * @hide
+     * Register a NetworkAgent with ConnectivityService.
+     * @return NetID corresponding to NetworkAgent.
+     */
+    public int registerNetworkAgent(Messenger messenger, NetworkInfo ni, LinkProperties lp,
+            NetworkCapabilities nc, int score, NetworkMisc misc, int factorySerialNumber) {
         try {
-            return mService.registerNetworkAgent(messenger, ni, lp, nc, score, misc);
+            return mService.registerNetworkAgent(messenger, ni, lp, nc, score, misc,
+                    factorySerialNumber);
         } catch (RemoteException e) {
             throw e.rethrowFromSystemServer();
         }
diff --git a/core/java/android/net/DnsResolver.java b/core/java/android/net/DnsResolver.java
index d3bc3e6..93b8cf8 100644
--- a/core/java/android/net/DnsResolver.java
+++ b/core/java/android/net/DnsResolver.java
@@ -22,11 +22,11 @@
 import static android.os.MessageQueue.OnFileDescriptorEventListener.EVENT_ERROR;
 import static android.os.MessageQueue.OnFileDescriptorEventListener.EVENT_INPUT;
 
+import android.annotation.CallbackExecutor;
 import android.annotation.IntDef;
 import android.annotation.NonNull;
 import android.annotation.Nullable;
-import android.os.Handler;
-import android.os.MessageQueue;
+import android.os.Looper;
 import android.system.ErrnoException;
 import android.util.Log;
 
@@ -37,8 +37,7 @@
 import java.net.UnknownHostException;
 import java.util.ArrayList;
 import java.util.List;
-import java.util.function.Consumer;
-
+import java.util.concurrent.Executor;
 
 /**
  * Dns resolver class for asynchronous dns querying
@@ -81,66 +80,137 @@
     public static final int FLAG_NO_CACHE_STORE = 1 << 1;
     public static final int FLAG_NO_CACHE_LOOKUP = 1 << 2;
 
-    private static final int DNS_RAW_RESPONSE = 1;
-
     private static final int NETID_UNSET = 0;
 
     private static final DnsResolver sInstance = new DnsResolver();
 
     /**
-     * listener for receiving raw answers
-     */
-    public interface RawAnswerListener {
-        /**
-         * {@code byte[]} is {@code null} if query timed out
-         */
-        void onAnswer(@Nullable byte[] answer);
-    }
-
-    /**
-     * listener for receiving parsed answers
-     */
-    public interface InetAddressAnswerListener {
-        /**
-         * Will be called exactly once with all the answers to the query.
-         * size of addresses will be zero if no available answer could be parsed.
-         */
-        void onAnswer(@NonNull List<InetAddress> addresses);
-    }
-
-    /**
      * Get instance for DnsResolver
      */
-    public static DnsResolver getInstance() {
+    public static @NonNull DnsResolver getInstance() {
         return sInstance;
     }
 
     private DnsResolver() {}
 
     /**
-     * Pass in a blob and corresponding setting,
-     * get a blob back asynchronously with the entire raw answer.
+     * Answer parser for parsing raw answers
+     *
+     * @param <T> The type of the parsed answer
+     */
+    public interface AnswerParser<T> {
+        /**
+         * Creates a <T> answer by parsing the given raw answer.
+         *
+         * @param rawAnswer the raw answer to be parsed
+         * @return a parsed <T> answer
+         * @throws ParseException if parsing failed
+         */
+        @NonNull T parse(@NonNull byte[] rawAnswer) throws ParseException;
+    }
+
+    /**
+     * Base class for answer callbacks
+     *
+     * @param <T> The type of the parsed answer
+     */
+    public abstract static class AnswerCallback<T> {
+        /** @hide */
+        public final AnswerParser<T> parser;
+
+        public AnswerCallback(@NonNull AnswerParser<T> parser) {
+            this.parser = parser;
+        };
+
+        /**
+         * Success response to
+         * {@link android.net.DnsResolver#query query()}.
+         *
+         * Invoked when the answer to a query was successfully parsed.
+         *
+         * @param answer parsed answer to the query.
+         *
+         * {@see android.net.DnsResolver#query query()}
+         */
+        public abstract void onAnswer(@NonNull T answer);
+
+        /**
+         * Error response to
+         * {@link android.net.DnsResolver#query query()}.
+         *
+         * Invoked when there is no valid answer to
+         * {@link android.net.DnsResolver#query query()}
+         *
+         * @param exception a {@link ParseException} object with additional
+         *    detail regarding the failure
+         */
+        public abstract void onParseException(@NonNull ParseException exception);
+
+        /**
+         * Error response to
+         * {@link android.net.DnsResolver#query query()}.
+         *
+         * Invoked if an error happens when
+         * issuing the DNS query or receiving the result.
+         * {@link android.net.DnsResolver#query query()}
+         *
+         * @param exception an {@link ErrnoException} object with additional detail
+         *    regarding the failure
+         */
+        public abstract void onQueryException(@NonNull ErrnoException exception);
+    }
+
+    /**
+     * Callback for receiving raw answers
+     */
+    public abstract static class RawAnswerCallback extends AnswerCallback<byte[]> {
+        public RawAnswerCallback() {
+            super(rawAnswer -> rawAnswer);
+        }
+    }
+
+    /**
+     * Callback for receiving parsed {@link InetAddress} answers
+     *
+     * Note that if the answer does not contain any IP addresses,
+     * onAnswer will be called with an empty list.
+     */
+    public abstract static class InetAddressAnswerCallback
+            extends AnswerCallback<List<InetAddress>> {
+        public InetAddressAnswerCallback() {
+            super(rawAnswer -> new DnsAddressAnswer(rawAnswer).getAddresses());
+        }
+    }
+
+    /**
+     * Send a raw DNS query.
+     * The answer will be provided asynchronously through the provided {@link AnswerCallback}.
      *
      * @param network {@link Network} specifying which network for querying.
      *         {@code null} for query on default network.
      * @param query blob message
      * @param flags flags as a combination of the FLAGS_* constants
-     * @param handler {@link Handler} to specify the thread
-     *         upon which the {@link RawAnswerListener} will be invoked.
-     * @param listener a {@link RawAnswerListener} which will be called to notify the caller
+     * @param executor The {@link Executor} that the callback should be executed on.
+     * @param callback an {@link AnswerCallback} which will be called to notify the caller
      *         of the result of dns query.
      */
-    public void query(@Nullable Network network, @NonNull byte[] query, @QueryFlag int flags,
-            @NonNull Handler handler, @NonNull RawAnswerListener listener) throws ErrnoException {
-        final FileDescriptor queryfd = resNetworkSend((network != null
+    public <T> void query(@Nullable Network network, @NonNull byte[] query, @QueryFlag int flags,
+            @NonNull @CallbackExecutor Executor executor, @NonNull AnswerCallback<T> callback) {
+        final FileDescriptor queryfd;
+        try {
+            queryfd = resNetworkSend((network != null
                 ? network.netId : NETID_UNSET), query, query.length, flags);
-        registerFDListener(handler.getLooper().getQueue(), queryfd,
-                answerbuf -> listener.onAnswer(answerbuf));
+        } catch (ErrnoException e) {
+            callback.onQueryException(e);
+            return;
+        }
+
+        registerFDListener(executor, queryfd, callback);
     }
 
     /**
-     * Pass in a domain name and corresponding setting,
-     * get a blob back asynchronously with the entire raw answer.
+     * Send a DNS query with the specified name, class and query type.
+     * The answer will be provided asynchronously through the provided {@link AnswerCallback}.
      *
      * @param network {@link Network} specifying which network for querying.
      *         {@code null} for query on default network.
@@ -148,74 +218,53 @@
      * @param nsClass dns class as one of the CLASS_* constants
      * @param nsType dns resource record (RR) type as one of the TYPE_* constants
      * @param flags flags as a combination of the FLAGS_* constants
-     * @param handler {@link Handler} to specify the thread
-     *         upon which the {@link RawAnswerListener} will be invoked.
-     * @param listener a {@link RawAnswerListener} which will be called to notify the caller
+     * @param executor The {@link Executor} that the callback should be executed on.
+     * @param callback an {@link AnswerCallback} which will be called to notify the caller
      *         of the result of dns query.
      */
-    public void query(@Nullable Network network, @NonNull String domain, @QueryClass int nsClass,
-            @QueryType int nsType, @QueryFlag int flags,
-            @NonNull Handler handler, @NonNull RawAnswerListener listener) throws ErrnoException {
-        final FileDescriptor queryfd = resNetworkQuery((network != null
-                ? network.netId : NETID_UNSET), domain, nsClass, nsType, flags);
-        registerFDListener(handler.getLooper().getQueue(), queryfd,
-                answerbuf -> listener.onAnswer(answerbuf));
+    public <T> void query(@Nullable Network network, @NonNull String domain,
+            @QueryClass int nsClass, @QueryType int nsType, @QueryFlag int flags,
+            @NonNull @CallbackExecutor Executor executor, @NonNull AnswerCallback<T> callback) {
+        final FileDescriptor queryfd;
+        try {
+            queryfd = resNetworkQuery((network != null
+                    ? network.netId : NETID_UNSET), domain, nsClass, nsType, flags);
+        } catch (ErrnoException e) {
+            callback.onQueryException(e);
+            return;
+        }
+        registerFDListener(executor, queryfd, callback);
     }
 
-    /**
-     * Pass in a domain name and corresponding setting,
-     * get back a set of InetAddresses asynchronously.
-     *
-     * @param network {@link Network} specifying which network for querying.
-     *         {@code null} for query on default network.
-     * @param domain domain name for querying
-     * @param flags flags as a combination of the FLAGS_* constants
-     * @param handler {@link Handler} to specify the thread
-     *         upon which the {@link InetAddressAnswerListener} will be invoked.
-     * @param listener an {@link InetAddressAnswerListener} which will be called to
-     *         notify the caller of the result of dns query.
-     *
-     */
-    public void query(@Nullable Network network, @NonNull String domain, @QueryFlag int flags,
-            @NonNull Handler handler, @NonNull InetAddressAnswerListener listener)
-            throws ErrnoException {
-        final FileDescriptor v4fd = resNetworkQuery((network != null
-                ? network.netId : NETID_UNSET), domain, CLASS_IN, TYPE_A, flags);
-        final FileDescriptor v6fd = resNetworkQuery((network != null
-                ? network.netId : NETID_UNSET), domain, CLASS_IN, TYPE_AAAA, flags);
-
-        final InetAddressAnswerAccumulator accmulator =
-                new InetAddressAnswerAccumulator(2, listener);
-        final Consumer<byte[]> consumer = answerbuf ->
-                accmulator.accumulate(parseAnswers(answerbuf));
-
-        registerFDListener(handler.getLooper().getQueue(), v4fd, consumer);
-        registerFDListener(handler.getLooper().getQueue(), v6fd, consumer);
-    }
-
-    private void registerFDListener(@NonNull MessageQueue queue,
-            @NonNull FileDescriptor queryfd, @NonNull Consumer<byte[]> answerConsumer) {
-        queue.addOnFileDescriptorEventListener(
+    private <T> void registerFDListener(@NonNull Executor executor,
+            @NonNull FileDescriptor queryfd, @NonNull AnswerCallback<T> answerCallback) {
+        Looper.getMainLooper().getQueue().addOnFileDescriptorEventListener(
                 queryfd,
                 FD_EVENTS,
                 (fd, events) -> {
-                    byte[] answerbuf = null;
-                    try {
-                    // TODO: Implement result function in Java side instead of using JNI
-                    //       Because JNI method close fd prior than unregistering fd on
-                    //       event listener.
-                        answerbuf = resNetworkResult(fd);
-                    } catch (ErrnoException e) {
-                        Log.e(TAG, "resNetworkResult:" + e.toString());
-                    }
-                    answerConsumer.accept(answerbuf);
+                    executor.execute(() -> {
+                        byte[] answerbuf = null;
+                        try {
+                            answerbuf = resNetworkResult(fd);
+                        } catch (ErrnoException e) {
+                            Log.e(TAG, "resNetworkResult:" + e.toString());
+                            answerCallback.onQueryException(e);
+                            return;
+                        }
 
+                        try {
+                            answerCallback.onAnswer(
+                                    answerCallback.parser.parse(answerbuf));
+                        } catch (ParseException e) {
+                            answerCallback.onParseException(e);
+                        }
+                    });
                     // Unregister this fd listener
                     return 0;
                 });
     }
 
-    private class DnsAddressAnswer extends DnsPacket {
+    private static class DnsAddressAnswer extends DnsPacket {
         private static final String TAG = "DnsResolver.DnsAddressAnswer";
         private static final boolean DBG = false;
 
@@ -226,12 +275,6 @@
             if ((mHeader.flags & (1 << 15)) == 0) {
                 throw new ParseException("Not an answer packet");
             }
-            if (mHeader.rcode != 0) {
-                throw new ParseException("Response error, rcode:" + mHeader.rcode);
-            }
-            if (mHeader.getRecordCount(ANSECTION) == 0) {
-                throw new ParseException("No available answer");
-            }
             if (mHeader.getRecordCount(QDSECTION) == 0) {
                 throw new ParseException("No question found");
             }
@@ -241,6 +284,8 @@
 
         public @NonNull List<InetAddress> getAddresses() {
             final List<InetAddress> results = new ArrayList<InetAddress>();
+            if (mHeader.getRecordCount(ANSECTION) == 0) return results;
+
             for (final DnsRecord ansSec : mRecords[ANSECTION]) {
                 // Only support A and AAAA, also ignore answers if query type != answer type.
                 int nsType = ansSec.nsType;
@@ -259,34 +304,4 @@
         }
     }
 
-    private @Nullable List<InetAddress> parseAnswers(@Nullable byte[] data) {
-        try {
-            return (data == null) ? null : new DnsAddressAnswer(data).getAddresses();
-        } catch (DnsPacket.ParseException e) {
-            Log.e(TAG, "Parse answer fail " + e.getMessage());
-            return null;
-        }
-    }
-
-    private class InetAddressAnswerAccumulator {
-        private final List<InetAddress> mAllAnswers;
-        private final InetAddressAnswerListener mAnswerListener;
-        private final int mTargetAnswerCount;
-        private int mReceivedAnswerCount = 0;
-
-        InetAddressAnswerAccumulator(int size, @NonNull InetAddressAnswerListener listener) {
-            mTargetAnswerCount = size;
-            mAllAnswers = new ArrayList<>();
-            mAnswerListener = listener;
-        }
-
-        public void accumulate(@Nullable List<InetAddress> answer) {
-            if (null != answer) {
-                mAllAnswers.addAll(answer);
-            }
-            if (++mReceivedAnswerCount == mTargetAnswerCount) {
-                mAnswerListener.onAnswer(mAllAnswers);
-            }
-        }
-    }
 }
diff --git a/core/java/android/net/IConnectivityManager.aidl b/core/java/android/net/IConnectivityManager.aidl
index 2df4e75..24e6a85 100644
--- a/core/java/android/net/IConnectivityManager.aidl
+++ b/core/java/android/net/IConnectivityManager.aidl
@@ -27,6 +27,7 @@
 import android.net.NetworkQuotaInfo;
 import android.net.NetworkRequest;
 import android.net.NetworkState;
+import android.net.ISocketKeepaliveCallback;
 import android.net.ProxyInfo;
 import android.os.Bundle;
 import android.os.IBinder;
@@ -150,14 +151,14 @@
 
     void setAirplaneMode(boolean enable);
 
-    void registerNetworkFactory(in Messenger messenger, in String name);
+    int registerNetworkFactory(in Messenger messenger, in String name);
 
     boolean requestBandwidthUpdate(in Network network);
 
     void unregisterNetworkFactory(in Messenger messenger);
 
     int registerNetworkAgent(in Messenger messenger, in NetworkInfo ni, in LinkProperties lp,
-            in NetworkCapabilities nc, int score, in NetworkMisc misc);
+            in NetworkCapabilities nc, int score, in NetworkMisc misc, in int factorySerialNumber);
 
     NetworkRequest requestNetwork(in NetworkCapabilities networkCapabilities,
             in Messenger messenger, int timeoutSec, in IBinder binder, int legacy);
@@ -194,15 +195,15 @@
 
     void factoryReset();
 
-    void startNattKeepalive(in Network network, int intervalSeconds, in Messenger messenger,
-            in IBinder binder, String srcAddr, int srcPort, String dstAddr);
+    void startNattKeepalive(in Network network, int intervalSeconds,
+            in ISocketKeepaliveCallback cb, String srcAddr, int srcPort, String dstAddr);
 
     void startNattKeepaliveWithFd(in Network network, in FileDescriptor fd, int resourceId,
-            int intervalSeconds, in Messenger messenger, in IBinder binder, String srcAddr,
+            int intervalSeconds, in ISocketKeepaliveCallback cb, String srcAddr,
             String dstAddr);
 
     void startTcpKeepalive(in Network network, in FileDescriptor fd, int intervalSeconds,
-            in Messenger messenger, in IBinder binder);
+            in ISocketKeepaliveCallback cb);
 
     void stopKeepalive(in Network network, int slot);
 
@@ -219,4 +220,6 @@
 
     void registerTetheringEventCallback(ITetheringEventCallback callback, String callerPkg);
     void unregisterTetheringEventCallback(ITetheringEventCallback callback, String callerPkg);
+
+    IBinder startOrGetTestNetworkService();
 }
diff --git a/core/java/android/net/NattSocketKeepalive.java b/core/java/android/net/NattSocketKeepalive.java
index 88631ae..84da294 100644
--- a/core/java/android/net/NattSocketKeepalive.java
+++ b/core/java/android/net/NattSocketKeepalive.java
@@ -17,7 +17,6 @@
 package android.net;
 
 import android.annotation.NonNull;
-import android.os.Binder;
 import android.os.RemoteException;
 import android.util.Log;
 
@@ -52,24 +51,30 @@
 
     @Override
     void startImpl(int intervalSec) {
-        try {
-            mService.startNattKeepaliveWithFd(mNetwork, mFd, mResourceId, intervalSec, mMessenger,
-                    new Binder(), mSource.getHostAddress(), mDestination.getHostAddress());
-        } catch (RemoteException e) {
-            Log.e(TAG, "Error starting packet keepalive: ", e);
-            stopLooper();
-        }
+        mExecutor.execute(() -> {
+            try {
+                mService.startNattKeepaliveWithFd(mNetwork, mFd, mResourceId, intervalSec,
+                        mCallback,
+                        mSource.getHostAddress(), mDestination.getHostAddress());
+            } catch (RemoteException e) {
+                Log.e(TAG, "Error starting socket keepalive: ", e);
+                throw e.rethrowFromSystemServer();
+            }
+        });
     }
 
     @Override
     void stopImpl() {
-        try {
-            if (mSlot != null) {
-                mService.stopKeepalive(mNetwork, mSlot);
+        mExecutor.execute(() -> {
+            try {
+                if (mSlot != null) {
+                    mService.stopKeepalive(mNetwork, mSlot);
+                }
+            } catch (RemoteException e) {
+                Log.e(TAG, "Error stopping socket keepalive: ", e);
+                throw e.rethrowFromSystemServer();
             }
-        } catch (RemoteException e) {
-            Log.e(TAG, "Error stopping packet keepalive: ", e);
-            stopLooper();
-        }
+        });
+
     }
 }
diff --git a/core/java/android/net/NetworkAgent.java b/core/java/android/net/NetworkAgent.java
index 7bef690..b55f6ba 100644
--- a/core/java/android/net/NetworkAgent.java
+++ b/core/java/android/net/NetworkAgent.java
@@ -57,6 +57,7 @@
     private static final long BW_REFRESH_MIN_WIN_MS = 500;
     private boolean mPollLceScheduled = false;
     private AtomicBoolean mPollLcePending = new AtomicBoolean(false);
+    public final int mFactorySerialNumber;
 
     private static final int BASE = Protocol.BASE_NETWORK_AGENT;
 
@@ -212,16 +213,31 @@
      */
     public static final int CMD_PREVENT_AUTOMATIC_RECONNECT = BASE + 15;
 
+    // TODO : remove these two constructors. They are a stopgap measure to help sheperding a number
+    // of dependent changes that would conflict throughout the automerger graph. Having these
+    // temporarily helps with the process of going through with all these dependent changes across
+    // the entire tree.
     public NetworkAgent(Looper looper, Context context, String logTag, NetworkInfo ni,
             NetworkCapabilities nc, LinkProperties lp, int score) {
-        this(looper, context, logTag, ni, nc, lp, score, null);
+        this(looper, context, logTag, ni, nc, lp, score, null, NetworkFactory.SerialNumber.NONE);
+    }
+    public NetworkAgent(Looper looper, Context context, String logTag, NetworkInfo ni,
+            NetworkCapabilities nc, LinkProperties lp, int score, NetworkMisc misc) {
+        this(looper, context, logTag, ni, nc, lp, score, misc, NetworkFactory.SerialNumber.NONE);
     }
 
     public NetworkAgent(Looper looper, Context context, String logTag, NetworkInfo ni,
-            NetworkCapabilities nc, LinkProperties lp, int score, NetworkMisc misc) {
+            NetworkCapabilities nc, LinkProperties lp, int score, int factorySerialNumber) {
+        this(looper, context, logTag, ni, nc, lp, score, null, factorySerialNumber);
+    }
+
+    public NetworkAgent(Looper looper, Context context, String logTag, NetworkInfo ni,
+            NetworkCapabilities nc, LinkProperties lp, int score, NetworkMisc misc,
+            int factorySerialNumber) {
         super(looper);
         LOG_TAG = logTag;
         mContext = context;
+        mFactorySerialNumber = factorySerialNumber;
         if (ni == null || nc == null || lp == null) {
             throw new IllegalArgumentException();
         }
@@ -230,7 +246,8 @@
         ConnectivityManager cm = (ConnectivityManager)mContext.getSystemService(
                 Context.CONNECTIVITY_SERVICE);
         netId = cm.registerNetworkAgent(new Messenger(this), new NetworkInfo(ni),
-                new LinkProperties(lp), new NetworkCapabilities(nc), score, misc);
+                new LinkProperties(lp), new NetworkCapabilities(nc), score, misc,
+                factorySerialNumber);
     }
 
     @Override
diff --git a/core/java/android/net/SocketKeepalive.java b/core/java/android/net/SocketKeepalive.java
index 07728be..0e768df 100644
--- a/core/java/android/net/SocketKeepalive.java
+++ b/core/java/android/net/SocketKeepalive.java
@@ -20,13 +20,8 @@
 import android.annotation.IntRange;
 import android.annotation.NonNull;
 import android.annotation.Nullable;
-import android.os.Handler;
-import android.os.HandlerThread;
-import android.os.Looper;
-import android.os.Message;
-import android.os.Messenger;
-import android.os.Process;
-import android.util.Log;
+import android.os.Binder;
+import android.os.RemoteException;
 
 import java.lang.annotation.Retention;
 import java.lang.annotation.RetentionPolicy;
@@ -152,10 +147,9 @@
 
     @NonNull final IConnectivityManager mService;
     @NonNull final Network mNetwork;
-    @NonNull private final Executor mExecutor;
-    @NonNull private final SocketKeepalive.Callback mCallback;
-    @NonNull private final Looper mLooper;
-    @NonNull final Messenger mMessenger;
+    @NonNull final Executor mExecutor;
+    @NonNull final ISocketKeepaliveCallback mCallback;
+    // TODO: remove slot since mCallback could be used to identify which keepalive to stop.
     @Nullable Integer mSlot;
 
     SocketKeepalive(@NonNull IConnectivityManager service, @NonNull Network network,
@@ -163,53 +157,53 @@
         mService = service;
         mNetwork = network;
         mExecutor = executor;
-        mCallback = callback;
-        // TODO: 1. Use other thread modeling instead of create one thread for every instance to
-        //          reduce the memory cost.
-        //       2. support restart.
-        //       3. Fix race condition which caused by rapidly start and stop.
-        HandlerThread thread = new HandlerThread(TAG, Process.THREAD_PRIORITY_BACKGROUND
-                + Process.THREAD_PRIORITY_LESS_FAVORABLE);
-        thread.start();
-        mLooper = thread.getLooper();
-        mMessenger = new Messenger(new Handler(mLooper) {
+        mCallback = new ISocketKeepaliveCallback.Stub() {
             @Override
-            public void handleMessage(Message message) {
-                switch (message.what) {
-                    case NetworkAgent.EVENT_SOCKET_KEEPALIVE:
-                        final int status = message.arg2;
-                        try {
-                            if (status == SUCCESS) {
-                                if (mSlot == null) {
-                                    mSlot = message.arg1;
-                                    mExecutor.execute(() -> mCallback.onStarted());
-                                } else {
-                                    mSlot = null;
-                                    stopLooper();
-                                    mExecutor.execute(() -> mCallback.onStopped());
-                                }
-                            } else if (status == DATA_RECEIVED) {
-                                stopLooper();
-                                mExecutor.execute(() -> mCallback.onDataReceived());
-                            } else {
-                                stopLooper();
-                                mExecutor.execute(() -> mCallback.onError(status));
-                            }
-                        } catch (Exception e) {
-                            Log.e(TAG, "Exception in keepalive callback(" + status + ")", e);
-                        }
-                        break;
-                    default:
-                        Log.e(TAG, "Unhandled message " + Integer.toHexString(message.what));
-                        break;
-                }
+            public void onStarted(int slot) {
+                Binder.withCleanCallingIdentity(() ->
+                        mExecutor.execute(() -> {
+                            mSlot = slot;
+                            callback.onStarted();
+                        }));
             }
-        });
+
+            @Override
+            public void onStopped() {
+                Binder.withCleanCallingIdentity(() ->
+                        executor.execute(() -> {
+                            mSlot = null;
+                            callback.onStopped();
+                        }));
+            }
+
+            @Override
+            public void onError(int error) {
+                Binder.withCleanCallingIdentity(() ->
+                        executor.execute(() -> {
+                            mSlot = null;
+                            callback.onError(error);
+                        }));
+            }
+
+            @Override
+            public void onDataReceived() {
+                Binder.withCleanCallingIdentity(() ->
+                        executor.execute(() -> {
+                            mSlot = null;
+                            callback.onDataReceived();
+                        }));
+            }
+        };
     }
 
     /**
      * Request that keepalive be started with the given {@code intervalSec}. See
-     * {@link SocketKeepalive}.
+     * {@link SocketKeepalive}. If the remote binder dies, or the binder call throws an exception
+     * when invoking start or stop of the {@link SocketKeepalive}, a {@link RemoteException} will be
+     * thrown into the {@code executor}. This is typically not important to catch because the remote
+     * party is the system, so if it is not in shape to communicate through binder the system is
+     * probably going down anyway. If the caller cares regardless, it can use a custom
+     * {@link Executor} to catch the {@link RemoteException}.
      *
      * @param intervalSec The target interval in seconds between keepalive packet transmissions.
      *                    The interval should be between 10 seconds and 3600 seconds, otherwise
@@ -222,12 +216,6 @@
 
     abstract void startImpl(int intervalSec);
 
-    /** @hide */
-    protected void stopLooper() {
-        // TODO: remove this after changing thread modeling.
-        mLooper.quit();
-    }
-
     /**
      * Requests that keepalive be stopped. The application must wait for {@link Callback#onStopped}
      * before using the object. See {@link SocketKeepalive}.
@@ -245,7 +233,6 @@
     @Override
     public final void close() {
         stop();
-        stopLooper();
     }
 
     /**
@@ -259,7 +246,8 @@
         public void onStopped() {}
         /** An error occurred. */
         public void onError(@ErrorCode int error) {}
-        /** The keepalive on a TCP socket was stopped because the socket received data. */
+        /** The keepalive on a TCP socket was stopped because the socket received data. This is
+         * never called for UDP sockets. */
         public void onDataReceived() {}
     }
 }
diff --git a/core/java/android/net/TcpSocketKeepalive.java b/core/java/android/net/TcpSocketKeepalive.java
index f691a0d..26cc8ff 100644
--- a/core/java/android/net/TcpSocketKeepalive.java
+++ b/core/java/android/net/TcpSocketKeepalive.java
@@ -17,7 +17,6 @@
 package android.net;
 
 import android.annotation.NonNull;
-import android.os.Binder;
 import android.os.RemoteException;
 import android.util.Log;
 
@@ -56,24 +55,28 @@
      */
     @Override
     void startImpl(int intervalSec) {
-        try {
-            final FileDescriptor fd = mSocket.getFileDescriptor$();
-            mService.startTcpKeepalive(mNetwork, fd, intervalSec, mMessenger, new Binder());
-        } catch (RemoteException e) {
-            Log.e(TAG, "Error starting packet keepalive: ", e);
-            stopLooper();
-        }
+        mExecutor.execute(() -> {
+            try {
+                final FileDescriptor fd = mSocket.getFileDescriptor$();
+                mService.startTcpKeepalive(mNetwork, fd, intervalSec, mCallback);
+            } catch (RemoteException e) {
+                Log.e(TAG, "Error starting packet keepalive: ", e);
+                throw e.rethrowFromSystemServer();
+            }
+        });
     }
 
     @Override
     void stopImpl() {
-        try {
-            if (mSlot != null) {
-                mService.stopKeepalive(mNetwork, mSlot);
+        mExecutor.execute(() -> {
+            try {
+                if (mSlot != null) {
+                    mService.stopKeepalive(mNetwork, mSlot);
+                }
+            } catch (RemoteException e) {
+                Log.e(TAG, "Error stopping packet keepalive: ", e);
+                throw e.rethrowFromSystemServer();
             }
-        } catch (RemoteException e) {
-            Log.e(TAG, "Error stopping packet keepalive: ", e);
-            stopLooper();
-        }
+        });
     }
 }
diff --git a/services/core/java/com/android/server/ConnectivityService.java b/services/core/java/com/android/server/ConnectivityService.java
index dbfc327..4416b4d 100644
--- a/services/core/java/com/android/server/ConnectivityService.java
+++ b/services/core/java/com/android/server/ConnectivityService.java
@@ -41,7 +41,6 @@
 import static android.net.NetworkPolicyManager.RULE_NONE;
 import static android.net.NetworkPolicyManager.uidRulesToString;
 import static android.net.shared.NetworkMonitorUtils.isValidationRequired;
-import static android.net.shared.NetworkParcelableUtil.toStableParcelable;
 import static android.os.Process.INVALID_UID;
 import static android.system.OsConstants.IPPROTO_TCP;
 import static android.system.OsConstants.IPPROTO_UDP;
@@ -73,6 +72,7 @@
 import android.net.INetworkPolicyListener;
 import android.net.INetworkPolicyManager;
 import android.net.INetworkStatsService;
+import android.net.ISocketKeepaliveCallback;
 import android.net.ITetheringEventCallback;
 import android.net.InetAddresses;
 import android.net.IpPrefix;
@@ -84,6 +84,7 @@
 import android.net.NetworkAgent;
 import android.net.NetworkCapabilities;
 import android.net.NetworkConfig;
+import android.net.NetworkFactory;
 import android.net.NetworkInfo;
 import android.net.NetworkInfo.DetailedState;
 import android.net.NetworkMisc;
@@ -298,6 +299,15 @@
     private INetworkPolicyManager mPolicyManager;
     private NetworkPolicyManagerInternal mPolicyManagerInternal;
 
+    /**
+     * TestNetworkService (lazily) created upon first usage. Locked to prevent creation of multiple
+     * instances.
+     */
+    @GuardedBy("mTNSLock")
+    private TestNetworkService mTNS;
+
+    private final Object mTNSLock = new Object();
+
     private String mCurrentTcpBufferSizes;
 
     private static final SparseArray<String> sMagicDecoderRing = MessageUtils.findMessageNames(
@@ -2892,8 +2902,17 @@
                 for (NetworkRequestInfo nri : mNetworkRequests.values()) {
                     if (nri.request.isListen()) continue;
                     NetworkAgentInfo nai = getNetworkForRequest(nri.request.requestId);
-                    ac.sendMessage(android.net.NetworkFactory.CMD_REQUEST_NETWORK,
-                            (nai != null ? nai.getCurrentScore() : 0), 0, nri.request);
+                    final int score;
+                    final int serial;
+                    if (nai != null) {
+                        score = nai.getCurrentScore();
+                        serial = nai.factorySerialNumber;
+                    } else {
+                        score = 0;
+                        serial = NetworkFactory.SerialNumber.NONE;
+                    }
+                    ac.sendMessage(android.net.NetworkFactory.CMD_REQUEST_NETWORK, score, serial,
+                            nri.request);
                 }
             } else {
                 loge("Error connecting NetworkFactory");
@@ -2991,7 +3010,7 @@
             NetworkAgentInfo currentNetwork = getNetworkForRequest(request.requestId);
             if (currentNetwork != null && currentNetwork.network.netId == nai.network.netId) {
                 clearNetworkForRequest(request.requestId);
-                sendUpdatedScoreToFactories(request, 0);
+                sendUpdatedScoreToFactories(request, null);
             }
         }
         nai.clearLingerState();
@@ -3068,7 +3087,7 @@
         }
         rematchAllNetworksAndRequests(null, 0);
         if (nri.request.isRequest() && getNetworkForRequest(nri.request.requestId) == null) {
-            sendUpdatedScoreToFactories(nri.request, 0);
+            sendUpdatedScoreToFactories(nri.request, null);
         }
     }
 
@@ -4843,11 +4862,14 @@
         public final String name;
         public final Messenger messenger;
         public final AsyncChannel asyncChannel;
+        public final int factorySerialNumber;
 
-        public NetworkFactoryInfo(String name, Messenger messenger, AsyncChannel asyncChannel) {
+        NetworkFactoryInfo(String name, Messenger messenger, AsyncChannel asyncChannel,
+                int factorySerialNumber) {
             this.name = name;
             this.messenger = messenger;
             this.asyncChannel = asyncChannel;
+            this.factorySerialNumber = factorySerialNumber;
         }
     }
 
@@ -5208,10 +5230,12 @@
     }
 
     @Override
-    public void registerNetworkFactory(Messenger messenger, String name) {
+    public int registerNetworkFactory(Messenger messenger, String name) {
         enforceConnectivityInternalPermission();
-        NetworkFactoryInfo nfi = new NetworkFactoryInfo(name, messenger, new AsyncChannel());
+        NetworkFactoryInfo nfi = new NetworkFactoryInfo(name, messenger, new AsyncChannel(),
+                NetworkFactory.SerialNumber.nextSerialNumber());
         mHandler.sendMessage(mHandler.obtainMessage(EVENT_REGISTER_NETWORK_FACTORY, nfi));
+        return nfi.factorySerialNumber;
     }
 
     private void handleRegisterNetworkFactory(NetworkFactoryInfo nfi) {
@@ -5316,9 +5340,35 @@
         return nri.request.requestId == mDefaultRequest.requestId;
     }
 
+    // TODO : remove this method. It's a stopgap measure to help sheperding a number of dependent
+    // changes that would conflict throughout the automerger graph. Having this method temporarily
+    // helps with the process of going through with all these dependent changes across the entire
+    // tree.
     public int registerNetworkAgent(Messenger messenger, NetworkInfo networkInfo,
             LinkProperties linkProperties, NetworkCapabilities networkCapabilities,
             int currentScore, NetworkMisc networkMisc) {
+        return registerNetworkAgent(messenger, networkInfo, linkProperties, networkCapabilities,
+                currentScore, networkMisc, NetworkFactory.SerialNumber.NONE);
+    }
+
+    /**
+     * Register a new agent with ConnectivityService to handle a network.
+     *
+     * @param messenger a messenger for ConnectivityService to contact the agent asynchronously.
+     * @param networkInfo the initial info associated with this network. It can be updated later :
+     *         see {@link #updateNetworkInfo}.
+     * @param linkProperties the initial link properties of this network. They can be updated
+     *         later : see {@link #updateLinkProperties}.
+     * @param networkCapabilities the initial capabilites of this network. They can be updated
+     *         later : see {@link #updateNetworkCapabilities}.
+     * @param currentScore the initial score of the network. See
+     *         {@link NetworkAgentInfo#getCurrentScore}.
+     * @param networkMisc metadata about the network. This is never updated.
+     * @param factorySerialNumber the serial number of the factory owning this NetworkAgent.
+     */
+    public int registerNetworkAgent(Messenger messenger, NetworkInfo networkInfo,
+            LinkProperties linkProperties, NetworkCapabilities networkCapabilities,
+            int currentScore, NetworkMisc networkMisc, int factorySerialNumber) {
         enforceConnectivityInternalPermission();
 
         LinkProperties lp = new LinkProperties(linkProperties);
@@ -5328,7 +5378,8 @@
         final NetworkCapabilities nc = new NetworkCapabilities(networkCapabilities);
         final NetworkAgentInfo nai = new NetworkAgentInfo(messenger, new AsyncChannel(),
                 new Network(reserveNetId()), new NetworkInfo(networkInfo), lp, nc, currentScore,
-                mContext, mTrackerHandler, new NetworkMisc(networkMisc), this, mNetd, mNMS);
+                mContext, mTrackerHandler, new NetworkMisc(networkMisc), this, mNetd, mNMS,
+                factorySerialNumber);
         // Make sure the network capabilities reflect what the agent info says.
         nai.networkCapabilities = mixInCapabilities(nai, nc);
         final String extraInfo = networkInfo.getExtraInfo();
@@ -5338,7 +5389,7 @@
         final long token = Binder.clearCallingIdentity();
         try {
             getNetworkStack().makeNetworkMonitor(
-                    toStableParcelable(nai.network), name, new NetworkMonitorCallbacks(nai));
+                    nai.network, name, new NetworkMonitorCallbacks(nai));
         } finally {
             Binder.restoreCallingIdentity(token);
         }
@@ -5755,17 +5806,23 @@
             NetworkRequest nr = nai.requestAt(i);
             // Don't send listening requests to factories. b/17393458
             if (nr.isListen()) continue;
-            sendUpdatedScoreToFactories(nr, nai.getCurrentScore());
+            sendUpdatedScoreToFactories(nr, nai);
         }
     }
 
-    private void sendUpdatedScoreToFactories(NetworkRequest networkRequest, int score) {
+    private void sendUpdatedScoreToFactories(NetworkRequest networkRequest, NetworkAgentInfo nai) {
+        int score = 0;
+        int serial = 0;
+        if (nai != null) {
+            score = nai.getCurrentScore();
+            serial = nai.factorySerialNumber;
+        }
         if (VDBG || DDBG){
             log("sending new Min Network Score(" + score + "): " + networkRequest.toString());
         }
         for (NetworkFactoryInfo nfi : mNetworkFactoryInfos.values()) {
-            nfi.asyncChannel.sendMessage(android.net.NetworkFactory.CMD_REQUEST_NETWORK, score, 0,
-                    networkRequest);
+            nfi.asyncChannel.sendMessage(android.net.NetworkFactory.CMD_REQUEST_NETWORK, score,
+                    serial, networkRequest);
         }
     }
 
@@ -6043,7 +6100,7 @@
                     // TODO - this could get expensive if we have a lot of requests for this
                     // network.  Think about if there is a way to reduce this.  Push
                     // netid->request mapping to each factory?
-                    sendUpdatedScoreToFactories(nri.request, score);
+                    sendUpdatedScoreToFactories(nri.request, newNetwork);
                     if (isDefaultRequest(nri)) {
                         isNewDefault = true;
                         oldDefaultNetwork = currentNetwork;
@@ -6067,7 +6124,7 @@
                 newNetwork.removeRequest(nri.request.requestId);
                 if (currentNetwork == newNetwork) {
                     clearNetworkForRequest(nri.request.requestId);
-                    sendUpdatedScoreToFactories(nri.request, 0);
+                    sendUpdatedScoreToFactories(nri.request, null);
                 } else {
                     Slog.wtf(TAG, "BUG: Removing request " + nri.request.requestId + " from " +
                             newNetwork.name() +
@@ -6642,32 +6699,32 @@
     }
 
     @Override
-    public void startNattKeepalive(Network network, int intervalSeconds, Messenger messenger,
-            IBinder binder, String srcAddr, int srcPort, String dstAddr) {
+    public void startNattKeepalive(Network network, int intervalSeconds,
+            ISocketKeepaliveCallback cb, String srcAddr, int srcPort, String dstAddr) {
         enforceKeepalivePermission();
         mKeepaliveTracker.startNattKeepalive(
                 getNetworkAgentInfoForNetwork(network),
-                intervalSeconds, messenger, binder,
+                intervalSeconds, cb,
                 srcAddr, srcPort, dstAddr, NattSocketKeepalive.NATT_PORT);
     }
 
     @Override
     public void startNattKeepaliveWithFd(Network network, FileDescriptor fd, int resourceId,
-            int intervalSeconds, Messenger messenger, IBinder binder, String srcAddr,
+            int intervalSeconds, ISocketKeepaliveCallback cb, String srcAddr,
             String dstAddr) {
         enforceKeepalivePermission();
         mKeepaliveTracker.startNattKeepalive(
                 getNetworkAgentInfoForNetwork(network), fd, resourceId,
-                intervalSeconds, messenger, binder,
+                intervalSeconds, cb,
                 srcAddr, dstAddr, NattSocketKeepalive.NATT_PORT);
     }
 
     @Override
     public void startTcpKeepalive(Network network, FileDescriptor fd, int intervalSeconds,
-            Messenger messenger, IBinder binder) {
+            ISocketKeepaliveCallback cb) {
         enforceKeepalivePermission();
         mKeepaliveTracker.startTcpKeepalive(
-                getNetworkAgentInfoForNetwork(network), fd, intervalSeconds, messenger, binder);
+                getNetworkAgentInfoForNetwork(network), fd, intervalSeconds, cb);
     }
 
     @Override
@@ -6910,4 +6967,22 @@
             return vpn != null && vpn.getLockdown();
         }
     }
+
+    /**
+     * Returns a IBinder to a TestNetworkService. Will be lazily created as needed.
+     *
+     * <p>The TestNetworkService must be run in the system server due to TUN creation.
+     */
+    @Override
+    public IBinder startOrGetTestNetworkService() {
+        synchronized (mTNSLock) {
+            TestNetworkService.enforceTestNetworkPermissions(mContext);
+
+            if (mTNS == null) {
+                mTNS = new TestNetworkService(mContext, mNMS);
+            }
+
+            return mTNS;
+        }
+    }
 }
diff --git a/services/core/java/com/android/server/connectivity/KeepaliveTracker.java b/services/core/java/com/android/server/connectivity/KeepaliveTracker.java
index cc4c173..35d6860 100644
--- a/services/core/java/com/android/server/connectivity/KeepaliveTracker.java
+++ b/services/core/java/com/android/server/connectivity/KeepaliveTracker.java
@@ -21,8 +21,8 @@
 import static android.net.NetworkAgent.CMD_REMOVE_KEEPALIVE_PACKET_FILTER;
 import static android.net.NetworkAgent.CMD_START_SOCKET_KEEPALIVE;
 import static android.net.NetworkAgent.CMD_STOP_SOCKET_KEEPALIVE;
-import static android.net.NetworkAgent.EVENT_SOCKET_KEEPALIVE;
 import static android.net.SocketKeepalive.BINDER_DIED;
+import static android.net.SocketKeepalive.DATA_RECEIVED;
 import static android.net.SocketKeepalive.ERROR_INVALID_INTERVAL;
 import static android.net.SocketKeepalive.ERROR_INVALID_IP_ADDRESS;
 import static android.net.SocketKeepalive.ERROR_INVALID_NETWORK;
@@ -34,6 +34,7 @@
 
 import android.annotation.NonNull;
 import android.annotation.Nullable;
+import android.net.ISocketKeepaliveCallback;
 import android.net.KeepalivePacketData;
 import android.net.NattKeepalivePacketData;
 import android.net.NetworkAgent;
@@ -47,7 +48,6 @@
 import android.os.Handler;
 import android.os.IBinder;
 import android.os.Message;
-import android.os.Messenger;
 import android.os.Process;
 import android.os.RemoteException;
 import android.system.ErrnoException;
@@ -99,8 +99,7 @@
      */
     class KeepaliveInfo implements IBinder.DeathRecipient {
         // Bookkeeping data.
-        private final Messenger mMessenger;
-        private final IBinder mBinder;
+        private final ISocketKeepaliveCallback mCallback;
         private final int mUid;
         private final int mPid;
         private final NetworkAgentInfo mNai;
@@ -124,15 +123,13 @@
         private static final int STARTED = 3;
         private int mStartedState = NOT_STARTED;
 
-        KeepaliveInfo(@NonNull Messenger messenger,
-                @NonNull IBinder binder,
+        KeepaliveInfo(@NonNull ISocketKeepaliveCallback callback,
                 @NonNull NetworkAgentInfo nai,
                 @NonNull KeepalivePacketData packet,
                 int interval,
                 int type,
                 @NonNull FileDescriptor fd) {
-            mMessenger = messenger;
-            mBinder = binder;
+            mCallback = callback;
             mPid = Binder.getCallingPid();
             mUid = Binder.getCallingUid();
 
@@ -143,7 +140,7 @@
             mFd = fd;
 
             try {
-                mBinder.linkToDeath(this, 0);
+                mCallback.asBinder().linkToDeath(this, 0);
             } catch (RemoteException e) {
                 binderDied();
             }
@@ -176,22 +173,14 @@
                     + " ]";
         }
 
-        /** Sends a message back to the application via its SocketKeepalive.Callback. */
-        void notifyMessenger(int slot, int err) {
-            if (DBG) {
-                Log.d(TAG, "notify keepalive " + mSlot + " on " + mNai.network + " for " + err);
-            }
-            KeepaliveTracker.this.notifyMessenger(mMessenger, slot, err);
-        }
-
         /** Called when the application process is killed. */
         public void binderDied() {
             stop(BINDER_DIED);
         }
 
         void unlinkDeathRecipient() {
-            if (mBinder != null) {
-                mBinder.unlinkToDeath(this, 0);
+            if (mCallback != null) {
+                mCallback.asBinder().unlinkToDeath(this, 0);
             }
         }
 
@@ -283,9 +272,23 @@
                     Log.wtf(TAG, "Stopping keepalive with unknown type: " + mType);
                 }
             }
-            // TODO: at the moment we unconditionally return failure here. In cases where the
-            // NetworkAgent is alive, should we ask it to reply, so it can return failure?
-            notifyMessenger(mSlot, reason);
+
+            if (reason == SUCCESS) {
+                try {
+                    mCallback.onStopped();
+                } catch (RemoteException e) {
+                    Log.w(TAG, "Discarded onStop callback: " + reason);
+                }
+            } else if (reason == DATA_RECEIVED) {
+                try {
+                    mCallback.onDataReceived();
+                } catch (RemoteException e) {
+                    Log.w(TAG, "Discarded onDataReceived callback: " + reason);
+                }
+            } else {
+                notifyErrorCallback(mCallback, reason);
+            }
+
             unlinkDeathRecipient();
         }
 
@@ -294,16 +297,12 @@
         }
     }
 
-    void notifyMessenger(Messenger messenger, int slot, int err) {
-        Message message = Message.obtain();
-        message.what = EVENT_SOCKET_KEEPALIVE;
-        message.arg1 = slot;
-        message.arg2 = err;
-        message.obj = null;
+    void notifyErrorCallback(ISocketKeepaliveCallback cb, int error) {
+        if (DBG) Log.w(TAG, "Sending onError(" + error + ") callback");
         try {
-            messenger.send(message);
+            cb.onError(error);
         } catch (RemoteException e) {
-            // Process died?
+            Log.w(TAG, "Discarded onError(" + error + ") callback");
         }
     }
 
@@ -414,7 +413,11 @@
             // Keepalive successfully started.
             if (DBG) Log.d(TAG, "Started keepalive " + slot + " on " + nai.name());
             ki.mStartedState = KeepaliveInfo.STARTED;
-            ki.notifyMessenger(slot, reason);
+            try {
+                ki.mCallback.onStarted(slot);
+            } catch (RemoteException e) {
+                Log.w(TAG, "Discarded onStarted(" + slot + ") callback");
+            }
         } else {
             // Keepalive successfully stopped, or error.
             ki.mStartedState = KeepaliveInfo.NOT_STARTED;
@@ -436,14 +439,13 @@
      **/
     public void startNattKeepalive(@Nullable NetworkAgentInfo nai,
             int intervalSeconds,
-            @NonNull Messenger messenger,
-            @NonNull IBinder binder,
+            @NonNull ISocketKeepaliveCallback cb,
             @NonNull String srcAddrString,
             int srcPort,
             @NonNull String dstAddrString,
             int dstPort) {
         if (nai == null) {
-            notifyMessenger(messenger, NO_KEEPALIVE, ERROR_INVALID_NETWORK);
+            notifyErrorCallback(cb, ERROR_INVALID_NETWORK);
             return;
         }
 
@@ -452,7 +454,7 @@
             srcAddress = NetworkUtils.numericToInetAddress(srcAddrString);
             dstAddress = NetworkUtils.numericToInetAddress(dstAddrString);
         } catch (IllegalArgumentException e) {
-            notifyMessenger(messenger, NO_KEEPALIVE, ERROR_INVALID_IP_ADDRESS);
+            notifyErrorCallback(cb, ERROR_INVALID_IP_ADDRESS);
             return;
         }
 
@@ -461,11 +463,12 @@
             packet = NattKeepalivePacketData.nattKeepalivePacket(
                     srcAddress, srcPort, dstAddress, NATT_PORT);
         } catch (InvalidPacketException e) {
-            notifyMessenger(messenger, NO_KEEPALIVE, e.error);
+            notifyErrorCallback(cb, e.error);
             return;
         }
-        KeepaliveInfo ki = new KeepaliveInfo(messenger, binder, nai, packet, intervalSeconds,
+        KeepaliveInfo ki = new KeepaliveInfo(cb, nai, packet, intervalSeconds,
                 KeepaliveInfo.TYPE_NATT, null);
+        Log.d(TAG, "Created keepalive: " + ki.toString());
         mConnectivityServiceHandler.obtainMessage(
                 NetworkAgent.CMD_START_SOCKET_KEEPALIVE, ki).sendToTarget();
     }
@@ -483,10 +486,9 @@
     public void startTcpKeepalive(@Nullable NetworkAgentInfo nai,
             @NonNull FileDescriptor fd,
             int intervalSeconds,
-            @NonNull Messenger messenger,
-            @NonNull IBinder binder) {
+            @NonNull ISocketKeepaliveCallback cb) {
         if (nai == null) {
-            notifyMessenger(messenger, NO_KEEPALIVE, ERROR_INVALID_NETWORK);
+            notifyErrorCallback(cb, ERROR_INVALID_NETWORK);
             return;
         }
 
@@ -500,10 +502,10 @@
             } catch (ErrnoException e1) {
                 Log.e(TAG, "Couldn't move fd out of repair mode after failure to start keepalive");
             }
-            notifyMessenger(messenger, NO_KEEPALIVE, e.error);
+            notifyErrorCallback(cb, e.error);
             return;
         }
-        KeepaliveInfo ki = new KeepaliveInfo(messenger, binder, nai, packet, intervalSeconds,
+        KeepaliveInfo ki = new KeepaliveInfo(cb, nai, packet, intervalSeconds,
                 KeepaliveInfo.TYPE_TCP, fd);
         Log.d(TAG, "Created keepalive: " + ki.toString());
         mConnectivityServiceHandler.obtainMessage(CMD_START_SOCKET_KEEPALIVE, ki).sendToTarget();
@@ -520,14 +522,13 @@
             @Nullable FileDescriptor fd,
             int resourceId,
             int intervalSeconds,
-            @NonNull Messenger messenger,
-            @NonNull IBinder binder,
+            @NonNull ISocketKeepaliveCallback cb,
             @NonNull String srcAddrString,
             @NonNull String dstAddrString,
             int dstPort) {
         // Ensure that the socket is created by IpSecService.
         if (!isNattKeepaliveSocketValid(fd, resourceId)) {
-            notifyMessenger(messenger, NO_KEEPALIVE, ERROR_INVALID_SOCKET);
+            notifyErrorCallback(cb, ERROR_INVALID_SOCKET);
         }
 
         // Get src port to adopt old API.
@@ -536,11 +537,11 @@
             final SocketAddress srcSockAddr = Os.getsockname(fd);
             srcPort = ((InetSocketAddress) srcSockAddr).getPort();
         } catch (ErrnoException e) {
-            notifyMessenger(messenger, NO_KEEPALIVE, ERROR_INVALID_SOCKET);
+            notifyErrorCallback(cb, ERROR_INVALID_SOCKET);
         }
 
         // Forward request to old API.
-        startNattKeepalive(nai, intervalSeconds, messenger, binder, srcAddrString, srcPort,
+        startNattKeepalive(nai, intervalSeconds, cb, srcAddrString, srcPort,
                 dstAddrString, dstPort);
     }
 
diff --git a/services/core/java/com/android/server/connectivity/NetworkAgentInfo.java b/services/core/java/com/android/server/connectivity/NetworkAgentInfo.java
index 65eb158..8f2825c 100644
--- a/services/core/java/com/android/server/connectivity/NetworkAgentInfo.java
+++ b/services/core/java/com/android/server/connectivity/NetworkAgentInfo.java
@@ -238,6 +238,8 @@
     public final Messenger messenger;
     public final AsyncChannel asyncChannel;
 
+    public final int factorySerialNumber;
+
     // Used by ConnectivityService to keep track of 464xlat.
     public final Nat464Xlat clatd;
 
@@ -253,7 +255,7 @@
     public NetworkAgentInfo(Messenger messenger, AsyncChannel ac, Network net, NetworkInfo info,
             LinkProperties lp, NetworkCapabilities nc, int score, Context context, Handler handler,
             NetworkMisc misc, ConnectivityService connService, INetd netd,
-            INetworkManagementService nms) {
+            INetworkManagementService nms, int factorySerialNumber) {
         this.messenger = messenger;
         asyncChannel = ac;
         network = net;
@@ -266,6 +268,7 @@
         mContext = context;
         mHandler = handler;
         networkMisc = misc;
+        this.factorySerialNumber = factorySerialNumber;
     }
 
     /**
diff --git a/services/core/java/com/android/server/connectivity/NetworkNotificationManager.java b/services/core/java/com/android/server/connectivity/NetworkNotificationManager.java
index 053da0d..828a1e5 100644
--- a/services/core/java/com/android/server/connectivity/NetworkNotificationManager.java
+++ b/services/core/java/com/android/server/connectivity/NetworkNotificationManager.java
@@ -28,6 +28,7 @@
 import android.content.res.Resources;
 import android.net.wifi.WifiInfo;
 import android.os.UserHandle;
+import android.telephony.AccessNetworkConstants.TransportType;
 import android.telephony.TelephonyManager;
 import android.text.TextUtils;
 import android.util.Slog;
@@ -92,7 +93,7 @@
         return -1;
     }
 
-    private static String getTransportName(int transportType) {
+    private static String getTransportName(@TransportType int transportType) {
         Resources r = Resources.getSystem();
         String[] networkTypes = r.getStringArray(R.array.network_switch_type_name);
         try {
diff --git a/tests/net/Android.bp b/tests/net/Android.bp
index 2539c0f..c62d85e 100644
--- a/tests/net/Android.bp
+++ b/tests/net/Android.bp
@@ -13,7 +13,6 @@
         "mockito-target-minus-junit4",
         "platform-test-annotations",
         "services.core",
-        "services.ipmemorystore",
         "services.net",
     ],
     libs: [
diff --git a/tests/net/java/android/net/IpMemoryStoreTest.java b/tests/net/java/android/net/IpMemoryStoreTest.java
index 57ecc8f..18c6768 100644
--- a/tests/net/java/android/net/IpMemoryStoreTest.java
+++ b/tests/net/java/android/net/IpMemoryStoreTest.java
@@ -16,6 +16,9 @@
 
 package android.net;
 
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.Mockito.doAnswer;
+
 import android.content.Context;
 
 import androidx.test.filters.SmallTest;
@@ -33,13 +36,25 @@
     @Mock
     Context mMockContext;
     @Mock
+    NetworkStackClient mNetworkStackClient;
+    @Mock
     IIpMemoryStore mMockService;
     IpMemoryStore mStore;
 
     @Before
     public void setUp() {
         MockitoAnnotations.initMocks(this);
-        mStore = new IpMemoryStore(mMockContext, mMockService);
+        doAnswer(invocation -> {
+            ((IIpMemoryStoreCallbacks) invocation.getArgument(0))
+                    .onIpMemoryStoreFetched(mMockService);
+            return null;
+        }).when(mNetworkStackClient).fetchIpMemoryStore(any());
+        mStore = new IpMemoryStore(mMockContext) {
+            @Override
+            protected NetworkStackClient getNetworkStackClient() {
+                return mNetworkStackClient;
+            }
+        };
     }
 
     @Test
diff --git a/tests/net/java/com/android/server/ConnectivityServiceTest.java b/tests/net/java/com/android/server/ConnectivityServiceTest.java
index 83ef62e..3efdfd9 100644
--- a/tests/net/java/com/android/server/ConnectivityServiceTest.java
+++ b/tests/net/java/com/android/server/ConnectivityServiceTest.java
@@ -60,10 +60,10 @@
 import static android.net.NetworkPolicyManager.RULE_NONE;
 import static android.net.NetworkPolicyManager.RULE_REJECT_ALL;
 import static android.net.NetworkPolicyManager.RULE_REJECT_METERED;
-import static android.net.shared.NetworkParcelableUtil.fromStableParcelable;
 
 import static com.android.internal.util.TestUtils.waitForIdleHandler;
 import static com.android.internal.util.TestUtils.waitForIdleLooper;
+import static com.android.internal.util.TestUtils.waitForIdleSerialExecutor;
 
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertFalse;
@@ -88,6 +88,7 @@
 import static org.mockito.Mockito.verifyNoMoreInteractions;
 import static org.mockito.Mockito.when;
 
+import android.annotation.NonNull;
 import android.app.NotificationManager;
 import android.app.PendingIntent;
 import android.content.BroadcastReceiver;
@@ -123,7 +124,6 @@
 import android.net.NetworkInfo;
 import android.net.NetworkInfo.DetailedState;
 import android.net.NetworkMisc;
-import android.net.NetworkParcelable;
 import android.net.NetworkRequest;
 import android.net.NetworkSpecifier;
 import android.net.NetworkStackClient;
@@ -190,6 +190,8 @@
 
 import java.net.Inet4Address;
 import java.net.InetAddress;
+import java.net.InetSocketAddress;
+import java.net.Socket;
 import java.net.UnknownHostException;
 import java.util.ArrayList;
 import java.util.Arrays;
@@ -206,6 +208,7 @@
 import java.util.concurrent.LinkedBlockingQueue;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.function.Consumer;
 import java.util.function.Predicate;
 
 /**
@@ -497,8 +500,7 @@
                 fail(e.getMessage());
             }
 
-            final ArgumentCaptor<NetworkParcelable> nmNetworkCaptor =
-                    ArgumentCaptor.forClass(NetworkParcelable.class);
+            final ArgumentCaptor<Network> nmNetworkCaptor = ArgumentCaptor.forClass(Network.class);
             final ArgumentCaptor<INetworkMonitorCallbacks> nmCbCaptor =
                     ArgumentCaptor.forClass(INetworkMonitorCallbacks.class);
             doNothing().when(mNetworkStack).makeNetworkMonitor(
@@ -508,7 +510,7 @@
 
             mNetworkAgent = new NetworkAgent(mHandlerThread.getLooper(), mServiceContext,
                     "Mock-" + typeName, mNetworkInfo, mNetworkCapabilities,
-                    linkProperties, mScore, new NetworkMisc()) {
+                    linkProperties, mScore, new NetworkMisc(), NetworkFactory.SerialNumber.NONE) {
                 @Override
                 public void unwanted() { mDisconnected.open(); }
 
@@ -538,8 +540,7 @@
                 }
             };
 
-            assertEquals(
-                    mNetworkAgent.netId, fromStableParcelable(nmNetworkCaptor.getValue()).netId);
+            assertEquals(mNetworkAgent.netId, nmNetworkCaptor.getValue().netId);
             mNmCallbacks = nmCbCaptor.getValue();
 
             try {
@@ -743,7 +744,7 @@
     /**
      * A NetworkFactory that allows tests to wait until any in-flight NetworkRequest add or remove
      * operations have been processed. Before ConnectivityService can add or remove any requests,
-     * the factory must be told to expect those operations by calling expectAddRequests or
+     * the factory must be told to expect those operations by calling expectAddRequestsWithScores or
      * expectRemoveRequests.
      */
     private static class MockNetworkFactory extends NetworkFactory {
@@ -752,14 +753,16 @@
         private final AtomicBoolean mNetworkStarted = new AtomicBoolean(false);
 
         // Used to expect that requests be removed or added on a separate thread, without sleeping.
-        // Callers can call either expectAddRequests() or expectRemoveRequests() exactly once, then
-        // cause some other thread to add or remove requests, then call waitForRequests(). We can
-        // either expect requests to be added or removed, but not both, because CountDownLatch can
-        // only count in one direction.
-        private CountDownLatch mExpectations;
+        // Callers can call either expectAddRequestsWithScores() or expectRemoveRequests() exactly
+        // once, then cause some other thread to add or remove requests, then call
+        // waitForRequests().
+        // It is not possible to wait for both add and remove requests. When adding, the queue
+        // contains the expected score. When removing, the value is unused, all matters is the
+        // number of objects in the queue.
+        private final LinkedBlockingQueue<Integer> mExpectations;
 
         // Whether we are currently expecting requests to be added or removed. Valid only if
-        // mExpectations is non-null.
+        // mExpectations is non-empty.
         private boolean mExpectingAdditions;
 
         // Used to collect the networks requests managed by this factory. This is a duplicate of
@@ -769,6 +772,7 @@
         public MockNetworkFactory(Looper looper, Context context, String logTag,
                 NetworkCapabilities filter) {
             super(looper, context, logTag, filter);
+            mExpectations = new LinkedBlockingQueue<>();
         }
 
         public int getMyRequestCount() {
@@ -800,38 +804,44 @@
         }
 
         @Override
-        protected void handleAddRequest(NetworkRequest request, int score) {
-            // If we're expecting anything, we must be expecting additions.
-            if (mExpectations != null && !mExpectingAdditions) {
-                fail("Can't add requests while expecting requests to be removed");
-            }
+        protected void handleAddRequest(NetworkRequest request, int score,
+                int factorySerialNumber) {
+            synchronized (mExpectations) {
+                final Integer expectedScore = mExpectations.poll(); // null if the queue is empty
 
-            // Add the request.
-            mNetworkRequests.put(request.requestId, request);
-            super.handleAddRequest(request, score);
+                assertNotNull("Added more requests than expected (" + request + " score : "
+                        + score + ")", expectedScore);
+                // If we're expecting anything, we must be expecting additions.
+                if (!mExpectingAdditions) {
+                    fail("Can't add requests while expecting requests to be removed");
+                }
+                if (expectedScore != score) {
+                    fail("Expected score was " + expectedScore + " but actual was " + score
+                            + " in added request");
+                }
 
-            // Reduce the number of request additions we're waiting for.
-            if (mExpectingAdditions) {
-                assertTrue("Added more requests than expected", mExpectations.getCount() > 0);
-                mExpectations.countDown();
+                // Add the request.
+                mNetworkRequests.put(request.requestId, request);
+                super.handleAddRequest(request, score, factorySerialNumber);
+                mExpectations.notify();
             }
         }
 
         @Override
         protected void handleRemoveRequest(NetworkRequest request) {
-            // If we're expecting anything, we must be expecting removals.
-            if (mExpectations != null && mExpectingAdditions) {
-                fail("Can't remove requests while expecting requests to be added");
-            }
+            synchronized (mExpectations) {
+                final Integer expectedScore = mExpectations.poll(); // null if the queue is empty
 
-            // Remove the request.
-            mNetworkRequests.remove(request.requestId);
-            super.handleRemoveRequest(request);
+                assertTrue("Removed more requests than expected", expectedScore != null);
+                // If we're expecting anything, we must be expecting removals.
+                if (mExpectingAdditions) {
+                    fail("Can't remove requests while expecting requests to be added");
+                }
 
-            // Reduce the number of request removals we're waiting for.
-            if (!mExpectingAdditions) {
-                assertTrue("Removed more requests than expected", mExpectations.getCount() > 0);
-                mExpectations.countDown();
+                // Remove the request.
+                mNetworkRequests.remove(request.requestId);
+                super.handleRemoveRequest(request);
+                mExpectations.notify();
             }
         }
 
@@ -841,35 +851,42 @@
         }
 
         private void assertNoExpectations() {
-            if (mExpectations != null) {
-                fail("Can't add expectation, " + mExpectations.getCount() + " already pending");
+            if (mExpectations.size() != 0) {
+                fail("Can't add expectation, " + mExpectations.size() + " already pending");
             }
         }
 
-        // Expects that count requests will be added.
-        public void expectAddRequests(final int count) {
+        // Expects that requests with the specified scores will be added.
+        public void expectAddRequestsWithScores(final int... scores) {
             assertNoExpectations();
             mExpectingAdditions = true;
-            mExpectations = new CountDownLatch(count);
+            for (int score : scores) {
+                mExpectations.add(score);
+            }
         }
 
         // Expects that count requests will be removed.
         public void expectRemoveRequests(final int count) {
             assertNoExpectations();
             mExpectingAdditions = false;
-            mExpectations = new CountDownLatch(count);
+            for (int i = 0; i < count; ++i) {
+                mExpectations.add(0); // For removals the score is ignored so any value will do.
+            }
         }
 
         // Waits for the expected request additions or removals to happen within a timeout.
         public void waitForRequests() throws InterruptedException {
-            assertNotNull("Nothing to wait for", mExpectations);
-            mExpectations.await(TIMEOUT_MS, TimeUnit.MILLISECONDS);
-            final long count = mExpectations.getCount();
+            final long deadline = SystemClock.elapsedRealtime() + TIMEOUT_MS;
+            synchronized (mExpectations) {
+                while (mExpectations.size() > 0 && SystemClock.elapsedRealtime() < deadline) {
+                    mExpectations.wait(deadline - SystemClock.elapsedRealtime());
+                }
+            }
+            final long count = mExpectations.size();
             final String msg = count + " requests still not " +
                     (mExpectingAdditions ? "added" : "removed") +
                     " after " + TIMEOUT_MS + " ms";
             assertEquals(msg, 0, count);
-            mExpectations = null;
         }
 
         public SparseArray<NetworkRequest> waitForNetworkRequests(final int count)
@@ -2326,6 +2343,12 @@
         callback.expectCallback(CallbackState.LOST, mEthernetNetworkAgent);
     }
 
+    private int[] makeIntArray(final int size, final int value) {
+        final int[] array = new int[size];
+        Arrays.fill(array, value);
+        return array;
+    }
+
     private void tryNetworkFactoryRequests(int capability) throws Exception {
         // Verify NOT_RESTRICTED is set appropriately
         final NetworkCapabilities nc = new NetworkRequest.Builder().addCapability(capability)
@@ -2347,7 +2370,7 @@
                 mServiceContext, "testFactory", filter);
         testFactory.setScoreFilter(40);
         ConditionVariable cv = testFactory.getNetworkStartedCV();
-        testFactory.expectAddRequests(1);
+        testFactory.expectAddRequestsWithScores(0);
         testFactory.register();
         testFactory.waitForNetworkRequests(1);
         int expectedRequestCount = 1;
@@ -2358,7 +2381,7 @@
             assertFalse(testFactory.getMyStartRequested());
             NetworkRequest request = new NetworkRequest.Builder().addCapability(capability).build();
             networkCallback = new NetworkCallback();
-            testFactory.expectAddRequests(1);
+            testFactory.expectAddRequestsWithScores(0);  // New request
             mCm.requestNetwork(request, networkCallback);
             expectedRequestCount++;
             testFactory.waitForNetworkRequests(expectedRequestCount);
@@ -2378,7 +2401,7 @@
         // When testAgent connects, ConnectivityService will re-send us all current requests with
         // the new score. There are expectedRequestCount such requests, and we must wait for all of
         // them.
-        testFactory.expectAddRequests(expectedRequestCount);
+        testFactory.expectAddRequestsWithScores(makeIntArray(expectedRequestCount, 50));
         testAgent.connect(false);
         testAgent.addCapability(capability);
         waitFor(cv);
@@ -2386,7 +2409,7 @@
         assertFalse(testFactory.getMyStartRequested());
 
         // Bring in a bunch of requests.
-        testFactory.expectAddRequests(10);
+        testFactory.expectAddRequestsWithScores(makeIntArray(10, 50));
         assertEquals(expectedRequestCount, testFactory.getMyRequestCount());
         ConnectivityManager.NetworkCallback[] networkCallbacks =
                 new ConnectivityManager.NetworkCallback[10];
@@ -2409,8 +2432,11 @@
 
         // Drop the higher scored network.
         cv = testFactory.getNetworkStartedCV();
+        // With the default network disconnecting, the requests are sent with score 0 to factories.
+        testFactory.expectAddRequestsWithScores(makeIntArray(expectedRequestCount, 0));
         testAgent.disconnect();
         waitFor(cv);
+        testFactory.waitForNetworkRequests(expectedRequestCount);
         assertEquals(expectedRequestCount, testFactory.getMyRequestCount());
         assertTrue(testFactory.getMyStartRequested());
 
@@ -2524,7 +2550,8 @@
         verifyActiveNetwork(TRANSPORT_CELLULAR);
     }
 
-    @Test
+    // TODO: deflake and re-enable
+    // @Test
     public void testPartialConnectivity() {
         // Register network callback.
         NetworkRequest request = new NetworkRequest.Builder()
@@ -3332,22 +3359,23 @@
         testFactory.setScoreFilter(40);
 
         // Register the factory and expect it to start looking for a network.
-        testFactory.expectAddRequests(1);
+        testFactory.expectAddRequestsWithScores(0);  // Score 0 as the request is not served yet.
         testFactory.register();
         testFactory.waitForNetworkRequests(1);
         assertTrue(testFactory.getMyStartRequested());
 
         // Bring up wifi. The factory stops looking for a network.
         mWiFiNetworkAgent = new MockNetworkAgent(TRANSPORT_WIFI);
-        testFactory.expectAddRequests(2);  // Because the default request changes score twice.
+        // Score 60 - 40 penalty for not validated yet, then 60 when it validates
+        testFactory.expectAddRequestsWithScores(20, 60);
         mWiFiNetworkAgent.connect(true);
-        testFactory.waitForNetworkRequests(1);
+        testFactory.waitForRequests();
         assertFalse(testFactory.getMyStartRequested());
 
         ContentResolver cr = mServiceContext.getContentResolver();
 
         // Turn on mobile data always on. The factory starts looking again.
-        testFactory.expectAddRequests(1);
+        testFactory.expectAddRequestsWithScores(0);  // Always on requests comes up with score 0
         setAlwaysOnNetworks(true);
         testFactory.waitForNetworkRequests(2);
         assertTrue(testFactory.getMyStartRequested());
@@ -3355,7 +3383,7 @@
         // Bring up cell data and check that the factory stops looking.
         assertLength(1, mCm.getAllNetworks());
         mCellNetworkAgent = new MockNetworkAgent(TRANSPORT_CELLULAR);
-        testFactory.expectAddRequests(2);  // Because the cell request changes score twice.
+        testFactory.expectAddRequestsWithScores(10, 50);  // Unvalidated, then validated
         mCellNetworkAgent.connect(true);
         cellNetworkCallback.expectAvailableThenValidatedCallbacks(mCellNetworkAgent);
         testFactory.waitForNetworkRequests(2);
@@ -3669,7 +3697,7 @@
         testFactory.setScoreFilter(40);
 
         // Register the factory and expect it to receive the default request.
-        testFactory.expectAddRequests(1);
+        testFactory.expectAddRequestsWithScores(0); // default request score is 0, not served yet
         testFactory.register();
         SparseArray<NetworkRequest> requests = testFactory.waitForNetworkRequests(1);
 
@@ -3677,7 +3705,7 @@
         int origRequestId = requests.valueAt(0).requestId;
 
         // Now file the test request and expect it.
-        testFactory.expectAddRequests(1);
+        testFactory.expectAddRequestsWithScores(0);
         mCm.requestNetwork(nr, networkCallback);
         requests = testFactory.waitForNetworkRequests(2); // have 2 requests at this point
 
@@ -3732,7 +3760,7 @@
             }
         }
 
-        private LinkedBlockingQueue<CallbackValue> mCallbacks = new LinkedBlockingQueue<>();
+        private final LinkedBlockingQueue<CallbackValue> mCallbacks = new LinkedBlockingQueue<>();
 
         @Override
         public void onStarted() {
@@ -3807,6 +3835,11 @@
         }
 
         private LinkedBlockingQueue<CallbackValue> mCallbacks = new LinkedBlockingQueue<>();
+        private final Executor mExecutor;
+
+        TestSocketKeepaliveCallback(@NonNull Executor executor) {
+            mExecutor = executor;
+        }
 
         @Override
         public void onStarted() {
@@ -3844,6 +3877,12 @@
         public void expectError(int error) {
             expectCallback(new CallbackValue(CallbackType.ON_ERROR, error));
         }
+
+        public void assertNoCallback() {
+            waitForIdleSerialExecutor(mExecutor, TIMEOUT_MS);
+            CallbackValue cv = mCallbacks.peek();
+            assertNull("Unexpected callback: " + cv, cv);
+        }
     }
 
     private Network connectKeepaliveNetwork(LinkProperties lp) {
@@ -3950,19 +3989,6 @@
         myNet = connectKeepaliveNetwork(lp);
         mWiFiNetworkAgent.setStartKeepaliveError(PacketKeepalive.SUCCESS);
 
-        // Check things work as expected when the keepalive is stopped and the network disconnects.
-        ka = mCm.startNattKeepalive(myNet, validKaInterval, callback, myIPv4, 12345, dstIPv4);
-        callback.expectStarted();
-        ka.stop();
-        mWiFiNetworkAgent.disconnect();
-        waitFor(mWiFiNetworkAgent.getDisconnectedCV());
-        waitForIdle();
-        callback.expectStopped();
-
-        // Reconnect.
-        myNet = connectKeepaliveNetwork(lp);
-        mWiFiNetworkAgent.setStartKeepaliveError(PacketKeepalive.SUCCESS);
-
         // Check that keepalive slots start from 1 and increment. The first one gets slot 1.
         mWiFiNetworkAgent.setExpectedKeepaliveSlot(1);
         ka = mCm.startNattKeepalive(myNet, validKaInterval, callback, myIPv4, 12345, dstIPv4);
@@ -3992,17 +4018,24 @@
         callback3.expectStopped();
     }
 
-    @Test
-    public void testNattSocketKeepalives_SingleThreadExecutor() throws Exception {
+    // Helper method to prepare the executor and run test
+    private void runTestWithSerialExecutors(Consumer<Executor> functor) {
         final ExecutorService executorSingleThread = Executors.newSingleThreadExecutor();
-        doTestNattSocketKeepalivesWithExecutor(executorSingleThread);
+        final Executor executorInline = (Runnable r) -> r.run();
+        functor.accept(executorSingleThread);
         executorSingleThread.shutdown();
+        functor.accept(executorInline);
     }
 
     @Test
-    public void testNattSocketKeepalives_InlineExecutor() throws Exception {
-        final Executor executorInline = (Runnable r) -> r.run();
-        doTestNattSocketKeepalivesWithExecutor(executorInline);
+    public void testNattSocketKeepalives() {
+        runTestWithSerialExecutors(executor -> {
+            try {
+                doTestNattSocketKeepalivesWithExecutor(executor);
+            } catch (Exception e) {
+                fail(e.getMessage());
+            }
+        });
     }
 
     private void doTestNattSocketKeepalivesWithExecutor(Executor executor) throws Exception {
@@ -4031,7 +4064,7 @@
         Network notMyNet = new Network(61234);
         Network myNet = connectKeepaliveNetwork(lp);
 
-        TestSocketKeepaliveCallback callback = new TestSocketKeepaliveCallback();
+        TestSocketKeepaliveCallback callback = new TestSocketKeepaliveCallback(executor);
         SocketKeepalive ka;
 
         // Attempt to start keepalives with invalid parameters and check for errors.
@@ -4074,6 +4107,22 @@
         ka.stop();
         callback.expectStopped();
 
+        // Check that keepalive could be restarted.
+        ka.start(validKaInterval);
+        callback.expectStarted();
+        ka.stop();
+        callback.expectStopped();
+
+        // Check that keepalive can be restarted without waiting for callback.
+        ka.start(validKaInterval);
+        callback.expectStarted();
+        ka.stop();
+        ka.start(validKaInterval);
+        callback.expectStopped();
+        callback.expectStarted();
+        ka.stop();
+        callback.expectStopped();
+
         // Check that deleting the IP address stops the keepalive.
         LinkProperties bogusLp = new LinkProperties(lp);
         ka = mCm.createSocketKeepalive(myNet, testSocket, myIPv4, dstIPv4, executor, callback);
@@ -4098,20 +4147,7 @@
         final Network myNetAlias = myNet;
         assertNull(mCm.getNetworkCapabilities(myNetAlias));
         ka.stop();
-
-        // Reconnect.
-        myNet = connectKeepaliveNetwork(lp);
-        mWiFiNetworkAgent.setStartKeepaliveError(SocketKeepalive.SUCCESS);
-
-        // Check things work as expected when the keepalive is stopped and the network disconnects.
-        ka = mCm.createSocketKeepalive(myNet, testSocket, myIPv4, dstIPv4, executor, callback);
-        ka.start(validKaInterval);
-        callback.expectStarted();
-        ka.stop();
-        mWiFiNetworkAgent.disconnect();
-        waitFor(mWiFiNetworkAgent.getDisconnectedCV());
-        waitForIdle();
-        callback.expectStopped();
+        callback.assertNoCallback();
 
         // Reconnect.
         myNet = connectKeepaliveNetwork(lp);
@@ -4126,7 +4162,7 @@
         // The second one gets slot 2.
         mWiFiNetworkAgent.setExpectedKeepaliveSlot(2);
         final UdpEncapsulationSocket testSocket2 = mIpSec.openUdpEncapsulationSocket(6789);
-        TestSocketKeepaliveCallback callback2 = new TestSocketKeepaliveCallback();
+        TestSocketKeepaliveCallback callback2 = new TestSocketKeepaliveCallback(executor);
         SocketKeepalive ka2 =
                 mCm.createSocketKeepalive(myNet, testSocket2, myIPv4, dstIPv4, executor, callback2);
         ka2.start(validKaInterval);
@@ -4143,6 +4179,81 @@
 
         mWiFiNetworkAgent.disconnect();
         waitFor(mWiFiNetworkAgent.getDisconnectedCV());
+        mWiFiNetworkAgent = null;
+    }
+
+    @Test
+    public void testTcpSocketKeepalives() {
+        runTestWithSerialExecutors(executor -> {
+            try {
+                doTestTcpSocketKeepalivesWithExecutor(executor);
+            } catch (Exception e) {
+                fail(e.getMessage());
+            }
+        });
+    }
+
+    private void doTestTcpSocketKeepalivesWithExecutor(Executor executor) throws Exception {
+        final int srcPortV4 = 12345;
+        final int srcPortV6 = 23456;
+        final InetAddress myIPv4 = InetAddress.getByName("127.0.0.1");
+        final InetAddress myIPv6 = InetAddress.getByName("::1");
+
+        final int validKaInterval = 15;
+        final int invalidKaInterval = 9;
+
+        final LinkProperties lp = new LinkProperties();
+        lp.setInterfaceName("wlan12");
+        lp.addLinkAddress(new LinkAddress(myIPv6, 64));
+        lp.addLinkAddress(new LinkAddress(myIPv4, 25));
+        lp.addRoute(new RouteInfo(InetAddress.getByName("fe80::1234")));
+        lp.addRoute(new RouteInfo(InetAddress.getByName("127.0.0.254")));
+
+        final Network notMyNet = new Network(61234);
+        final Network myNet = connectKeepaliveNetwork(lp);
+
+        final Socket testSocketV4 = new Socket();
+        final Socket testSocketV6 = new Socket();
+
+        TestSocketKeepaliveCallback callback = new TestSocketKeepaliveCallback(executor);
+        SocketKeepalive ka;
+
+        // Attempt to start Tcp keepalives with invalid parameters and check for errors.
+        // Invalid network.
+        ka = mCm.createSocketKeepalive(notMyNet, testSocketV4, executor, callback);
+        ka.start(validKaInterval);
+        callback.expectError(SocketKeepalive.ERROR_INVALID_NETWORK);
+
+        // Invalid Socket (socket is not bound with IPv4 address).
+        ka = mCm.createSocketKeepalive(myNet, testSocketV4, executor, callback);
+        ka.start(validKaInterval);
+        callback.expectError(SocketKeepalive.ERROR_INVALID_SOCKET);
+
+        // Invalid Socket (socket is not bound with IPv6 address).
+        ka = mCm.createSocketKeepalive(myNet, testSocketV6, executor, callback);
+        ka.start(validKaInterval);
+        callback.expectError(SocketKeepalive.ERROR_INVALID_SOCKET);
+
+        // Bind the socket address
+        testSocketV4.bind(new InetSocketAddress(myIPv4, srcPortV4));
+        testSocketV6.bind(new InetSocketAddress(myIPv6, srcPortV6));
+
+        // Invalid Socket (socket is bound with IPv4 address).
+        ka = mCm.createSocketKeepalive(myNet, testSocketV4, executor, callback);
+        ka.start(validKaInterval);
+        callback.expectError(SocketKeepalive.ERROR_INVALID_SOCKET);
+
+        // Invalid Socket (socket is bound with IPv6 address).
+        ka = mCm.createSocketKeepalive(myNet, testSocketV6, executor, callback);
+        ka.start(validKaInterval);
+        callback.expectError(SocketKeepalive.ERROR_INVALID_SOCKET);
+
+        testSocketV4.close();
+        testSocketV6.close();
+
+        mWiFiNetworkAgent.disconnect();
+        waitFor(mWiFiNetworkAgent.getDisconnectedCV());
+        mWiFiNetworkAgent = null;
     }
 
     @Test
diff --git a/tests/net/java/com/android/server/connectivity/LingerMonitorTest.java b/tests/net/java/com/android/server/connectivity/LingerMonitorTest.java
index 38352b3..6de4aa1 100644
--- a/tests/net/java/com/android/server/connectivity/LingerMonitorTest.java
+++ b/tests/net/java/com/android/server/connectivity/LingerMonitorTest.java
@@ -35,6 +35,7 @@
 import android.net.INetd;
 import android.net.Network;
 import android.net.NetworkCapabilities;
+import android.net.NetworkFactory;
 import android.net.NetworkInfo;
 import android.net.NetworkMisc;
 import android.os.INetworkManagementService;
@@ -352,7 +353,8 @@
         caps.addCapability(0);
         caps.addTransportType(transport);
         NetworkAgentInfo nai = new NetworkAgentInfo(null, null, new Network(netId), info, null,
-                caps, 50, mCtx, null, mMisc, mConnService, mNetd, mNMS);
+                caps, 50, mCtx, null, mMisc, mConnService, mNetd, mNMS,
+                NetworkFactory.SerialNumber.NONE);
         nai.everValidated = true;
         return nai;
     }