Merge "Stop TCP keepalive from CS for fd initiated stop events"
diff --git a/service/src/com/android/server/connectivity/KeepaliveTracker.java b/service/src/com/android/server/connectivity/KeepaliveTracker.java
index 941b616..1fd8a62 100644
--- a/service/src/com/android/server/connectivity/KeepaliveTracker.java
+++ b/service/src/com/android/server/connectivity/KeepaliveTracker.java
@@ -106,7 +106,12 @@
     private final int mAllowedUnprivilegedSlotsForUid;
 
     public KeepaliveTracker(Context context, Handler handler) {
-        mTcpController = new TcpKeepaliveController(handler);
+        this(context, handler, new TcpKeepaliveController(handler));
+    }
+
+    @VisibleForTesting
+    KeepaliveTracker(Context context, Handler handler, TcpKeepaliveController tcpController) {
+        mTcpController = tcpController;
         mContext = context;
 
         mSupportedKeepalives = KeepaliveResourceUtil.getSupportedKeepalives(context);
@@ -375,7 +380,7 @@
                         break;
                     case TYPE_TCP:
                         try {
-                            mTcpController.startSocketMonitor(mFd, this, mSlot);
+                            mTcpController.startSocketMonitor(mFd, mCallback, mSlot);
                         } catch (InvalidSocketException e) {
                             handleStopKeepalive(mNai, mSlot, ERROR_INVALID_SOCKET);
                             return ERROR_INVALID_SOCKET;
@@ -455,12 +460,6 @@
             }
         }
 
-        // TODO: This does not clean up the autoKi in AutomaticOnOffKeepaliveTracker and it is not
-        // possible without a big refactor.
-        void onFileDescriptorInitiatedStop(final int socketKeepaliveReason) {
-            handleStopKeepalive(mNai, mSlot, socketKeepaliveReason);
-        }
-
         /**
          * Construct a new KeepaliveInfo from existing KeepaliveInfo with a new fd.
          */
diff --git a/service/src/com/android/server/connectivity/TcpKeepaliveController.java b/service/src/com/android/server/connectivity/TcpKeepaliveController.java
index a9cb2fa..0fd8604 100644
--- a/service/src/com/android/server/connectivity/TcpKeepaliveController.java
+++ b/service/src/com/android/server/connectivity/TcpKeepaliveController.java
@@ -15,6 +15,7 @@
  */
 package com.android.server.connectivity;
 
+import static android.net.NetworkAgent.CMD_STOP_SOCKET_KEEPALIVE;
 import static android.net.SocketKeepalive.DATA_RECEIVED;
 import static android.net.SocketKeepalive.ERROR_INVALID_IP_ADDRESS;
 import static android.net.SocketKeepalive.ERROR_INVALID_SOCKET;
@@ -33,6 +34,7 @@
 import static com.android.net.module.util.NetworkStackConstants.IPV4_HEADER_MIN_LEN;
 
 import android.annotation.NonNull;
+import android.net.ISocketKeepaliveCallback;
 import android.net.InvalidPacketException;
 import android.net.NetworkUtils;
 import android.net.SocketKeepalive.InvalidSocketException;
@@ -50,7 +52,6 @@
 import com.android.internal.annotations.GuardedBy;
 import com.android.internal.annotations.VisibleForTesting;
 import com.android.net.module.util.IpUtils;
-import com.android.server.connectivity.KeepaliveTracker.KeepaliveInfo;
 
 import java.io.FileDescriptor;
 import java.net.InetAddress;
@@ -88,6 +89,8 @@
 
     private final MessageQueue mFdHandlerQueue;
 
+    private final Handler mConnectivityServiceHandler;
+
     private static final int FD_EVENTS = EVENT_INPUT | EVENT_ERROR;
 
     private static final int TCP_HEADER_LENGTH = 20;
@@ -115,6 +118,7 @@
 
     public TcpKeepaliveController(final Handler connectivityServiceHandler) {
         mFdHandlerQueue = connectivityServiceHandler.getLooper().getQueue();
+        mConnectivityServiceHandler = connectivityServiceHandler;
     }
 
     /** Build tcp keepalive packet. */
@@ -324,12 +328,13 @@
      * Start monitoring incoming packets.
      *
      * @param fd socket fd to monitor.
-     * @param ki a {@link KeepaliveInfo} that tracks information about a socket keepalive.
+     * @param callback the {@link ISocketKeepaliveCallback} identifying the keepalive; its binder
+     *                 is sent with the stop message when an fd event stops the keepalive.
      * @param slot keepalive slot.
      */
-    public void startSocketMonitor(@NonNull final FileDescriptor fd,
-            @NonNull final KeepaliveInfo ki, final int slot)
-            throws IllegalArgumentException, InvalidSocketException {
+    public void startSocketMonitor(
+            @NonNull final FileDescriptor fd, @NonNull final ISocketKeepaliveCallback callback,
+            final int slot) throws IllegalArgumentException, InvalidSocketException {
         synchronized (mListeners) {
             if (null != mListeners.get(slot)) {
                 throw new IllegalArgumentException("This slot is already taken");
@@ -350,7 +355,8 @@
                 } else {
                     reason = DATA_RECEIVED;
                 }
-                ki.onFileDescriptorInitiatedStop(reason);
+                mConnectivityServiceHandler.obtainMessage(CMD_STOP_SOCKET_KEEPALIVE,
+                        0 /* unused */, reason, callback.asBinder()).sendToTarget();
                 // The listener returns the new set of events to listen to. Because 0 means no
                 // event, the listener gets unregistered.
                 return 0;
diff --git a/tests/unit/java/com/android/server/connectivity/AutomaticOnOffKeepaliveTrackerTest.java b/tests/unit/java/com/android/server/connectivity/AutomaticOnOffKeepaliveTrackerTest.java
index 608e6d8..db65c2b 100644
--- a/tests/unit/java/com/android/server/connectivity/AutomaticOnOffKeepaliveTrackerTest.java
+++ b/tests/unit/java/com/android/server/connectivity/AutomaticOnOffKeepaliveTrackerTest.java
@@ -18,6 +18,7 @@
 
 import static android.content.pm.PackageManager.PERMISSION_GRANTED;
 import static android.net.ConnectivityManager.TYPE_MOBILE;
+import static android.net.NetworkAgent.CMD_STOP_SOCKET_KEEPALIVE;
 import static android.net.NetworkCapabilities.TRANSPORT_CELLULAR;
 
 import static com.android.testutils.HandlerUtils.visibleOnHandlerThread;
@@ -28,6 +29,7 @@
 import static org.junit.Assert.assertNull;
 import static org.junit.Assert.assertThrows;
 import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.ArgumentMatchers.anyBoolean;
 import static org.mockito.ArgumentMatchers.anyInt;
@@ -47,6 +49,7 @@
 import android.content.res.Resources;
 import android.net.INetd;
 import android.net.ISocketKeepaliveCallback;
+import android.net.KeepalivePacketData;
 import android.net.LinkAddress;
 import android.net.LinkProperties;
 import android.net.MarkMaskParcel;
@@ -55,6 +58,7 @@
 import android.net.NetworkCapabilities;
 import android.net.NetworkInfo;
 import android.net.SocketKeepalive;
+import android.net.TcpKeepalivePacketData;
 import android.os.Binder;
 import android.os.Build;
 import android.os.Handler;
@@ -120,6 +124,7 @@
 
     TestKeepaliveTracker mKeepaliveTracker;
     AOOTestHandler mTestHandler;
+    TestTcpKeepaliveController mTcpController;
 
     // Hexadecimal representation of a SOCK_DIAG response with tcp info.
     private static final String SOCK_DIAG_TCP_INET_HEX =
@@ -233,9 +238,9 @@
         public final FileDescriptor fd;
         public final ISocketKeepaliveCallback socketKeepaliveCallback;
         public final Network underpinnedNetwork;
-        public final NattKeepalivePacketData kpd;
+        public final KeepalivePacketData kpd;
 
-        TestKeepaliveInfo(NattKeepalivePacketData kpd) throws Exception {
+        TestKeepaliveInfo(KeepalivePacketData kpd) throws Exception {
             this.kpd = kpd;
             socket = new Socket();
             socket.bind(null);
@@ -252,8 +257,9 @@
     private class TestKeepaliveTracker extends KeepaliveTracker {
         private KeepaliveInfo mKi;
 
-        TestKeepaliveTracker(@NonNull final Context context, @NonNull final Handler handler) {
-            super(context, handler);
+        TestKeepaliveTracker(@NonNull final Context context, @NonNull final Handler handler,
+                @NonNull final TcpKeepaliveController tcpController) {
+            super(context, handler, tcpController);
         }
 
         public void setReturnedKeepaliveInfo(@NonNull final KeepaliveInfo ki) {
@@ -272,6 +278,24 @@
             }
             return mKi;
         }
+
+        @NonNull
+        @Override
+        public KeepaliveInfo makeTcpKeepaliveInfo(@Nullable final NetworkAgentInfo nai,
+                @Nullable final FileDescriptor fd, final int intervalSeconds,
+                @NonNull final ISocketKeepaliveCallback cb) {
+            if (null == mKi) {
+                throw new IllegalStateException("Please call `setReturnedKeepaliveInfo`"
+                        + " before makeTcpKeepaliveInfo is called");
+            }
+            return mKi;
+        }
+    }
+
+    private static class TestTcpKeepaliveController extends TcpKeepaliveController {
+        TestTcpKeepaliveController(final Handler connectivityServiceHandler) {
+            super(connectivityServiceHandler);
+        }
     }
 
     @Before
@@ -303,7 +327,8 @@
         mHandlerThread = new HandlerThread("KeepaliveTrackerTest");
         mHandlerThread.start();
         mTestHandler = new AOOTestHandler(mHandlerThread.getLooper());
-        mKeepaliveTracker = new TestKeepaliveTracker(mCtx, mTestHandler);
+        mTcpController = new TestTcpKeepaliveController(mTestHandler);
+        mKeepaliveTracker = new TestKeepaliveTracker(mCtx, mTestHandler, mTcpController);
         doReturn(mKeepaliveTracker).when(mDependencies).newKeepaliveTracker(mCtx, mTestHandler);
         doReturn(true).when(mDependencies).isFeatureEnabled(any(), anyBoolean());
         mAOOKeepaliveTracker =
@@ -333,6 +358,14 @@
                     Log.d(TAG, "Test handler received CMD_MONITOR_AUTOMATIC_KEEPALIVE : " + msg);
                     mLastAutoKi = mAOOKeepaliveTracker.getKeepaliveForBinder((IBinder) msg.obj);
                     break;
+                case CMD_STOP_SOCKET_KEEPALIVE:
+                    Log.d(TAG, "Test handler received CMD_STOP_SOCKET_KEEPALIVE : " + msg);
+                    mLastAutoKi = mAOOKeepaliveTracker.getKeepaliveForBinder((IBinder) msg.obj);
+                    if (mLastAutoKi == null) {
+                        fail("Attempt to stop an already stopped keepalive");
+                    }
+                    mAOOKeepaliveTracker.handleStopKeepalive(mLastAutoKi, msg.arg2);
+                    break;
             }
         }
     }
@@ -481,14 +514,15 @@
                 mTestHandler, () -> mAOOKeepaliveTracker.getKeepaliveForBinder(binder));
     }
 
-    private void checkAndProcessKeepaliveStart(final NattKeepalivePacketData kpd) throws Exception {
+    private void checkAndProcessKeepaliveStart(final KeepalivePacketData kpd) throws Exception {
         checkAndProcessKeepaliveStart(TEST_SLOT, kpd);
     }
 
     private void checkAndProcessKeepaliveStart(
-            int slot, final NattKeepalivePacketData kpd) throws Exception {
-        verify(mNai).onStartNattSocketKeepalive(slot, TEST_KEEPALIVE_INTERVAL_SEC, kpd);
-        verify(mNai).onAddNattKeepalivePacketFilter(slot, kpd);
+            int slot, final KeepalivePacketData kpd) throws Exception {
+        verify(mNai).onStartNattSocketKeepalive(
+                slot, TEST_KEEPALIVE_INTERVAL_SEC, (NattKeepalivePacketData) kpd);
+        verify(mNai).onAddNattKeepalivePacketFilter(slot, (NattKeepalivePacketData) kpd);
         triggerEventKeepalive(slot, SocketKeepalive.SUCCESS);
     }
 
@@ -531,9 +565,10 @@
     public void testHandleEventSocketKeepalive_startingFailureHardwareError() throws Exception {
         final TestKeepaliveInfo testInfo = doStartNattKeepalive();
 
-        verify(mNai)
-                .onStartNattSocketKeepalive(TEST_SLOT, TEST_KEEPALIVE_INTERVAL_SEC, testInfo.kpd);
-        verify(mNai).onAddNattKeepalivePacketFilter(TEST_SLOT, testInfo.kpd);
+        verify(mNai).onStartNattSocketKeepalive(
+                TEST_SLOT, TEST_KEEPALIVE_INTERVAL_SEC, (NattKeepalivePacketData) testInfo.kpd);
+        verify(mNai).onAddNattKeepalivePacketFilter(
+                TEST_SLOT, (NattKeepalivePacketData) testInfo.kpd);
         // Network agent returns an error, fails to start the keepalive.
         triggerEventKeepalive(TEST_SLOT, SocketKeepalive.ERROR_HARDWARE_ERROR);
 
@@ -674,9 +709,10 @@
         clearInvocations(mNai);
         doResumeKeepalive(getAutoKiForBinder(testInfo.binder));
 
-        verify(mNai)
-                .onStartNattSocketKeepalive(TEST_SLOT, TEST_KEEPALIVE_INTERVAL_SEC, testInfo.kpd);
-        verify(mNai).onAddNattKeepalivePacketFilter(TEST_SLOT, testInfo.kpd);
+        verify(mNai).onStartNattSocketKeepalive(
+                TEST_SLOT, TEST_KEEPALIVE_INTERVAL_SEC, (NattKeepalivePacketData) testInfo.kpd);
+        verify(mNai).onAddNattKeepalivePacketFilter(
+                TEST_SLOT, (NattKeepalivePacketData) testInfo.kpd);
         // Network agent returns error on starting the keepalive.
         triggerEventKeepalive(TEST_SLOT, SocketKeepalive.ERROR_HARDWARE_ERROR);
 
@@ -772,4 +808,38 @@
         verifyNoMoreInteractions(ignoreStubs(testInfo1.socketKeepaliveCallback));
         verifyNoMoreInteractions(ignoreStubs(testInfo2.socketKeepaliveCallback));
     }
+
+    @Test
+    public void testStartTcpKeepalive_fdInitiatedStop() throws Exception {
+        final InetAddress srcAddress = InetAddress.getByAddress(
+                new byte[] { (byte) 192, 0, 0, (byte) 129 });
+        mNai.linkProperties.addLinkAddress(new LinkAddress(srcAddress, 24));
+
+        final KeepalivePacketData kpd = new TcpKeepalivePacketData(
+                srcAddress /* srcAddr, same address as the link address added above */,
+                12345 /* srcPort */,
+                InetAddress.getByAddress(new byte[] { 8, 8, 8, 8}) /* dstAddr */,
+                12345 /* dstPort */, new byte[] {1},  111 /* tcpSeq */,
+                222 /* tcpAck */, 800 /* tcpWindow */, 2 /* tcpWindowScale */,
+                4 /* ipTos */, 64 /* ipTtl */);
+        final TestKeepaliveInfo testInfo = new TestKeepaliveInfo(kpd);
+
+        final KeepaliveInfo ki = mKeepaliveTracker.new KeepaliveInfo(
+                testInfo.socketKeepaliveCallback, mNai, kpd,
+                TEST_KEEPALIVE_INTERVAL_SEC, KeepaliveInfo.TYPE_TCP, testInfo.fd);
+        mKeepaliveTracker.setReturnedKeepaliveInfo(ki);
+
+        // Set up the TCP keepalive and let the handler thread process the start request.
+        mAOOKeepaliveTracker.startTcpKeepalive(mNai, testInfo.fd, TEST_KEEPALIVE_INTERVAL_SEC,
+                testInfo.socketKeepaliveCallback);
+        HandlerUtils.waitForIdle(mTestHandler, TIMEOUT_MS);
+
+        // A closed socket will result in EVENT_HANGUP and trigger error to
+        // FileDescriptorEventListener.
+        testInfo.socket.close();
+        HandlerUtils.waitForIdle(mTestHandler, TIMEOUT_MS);
+
+        // The keepalive must have been removed from AutomaticOnOffKeepaliveTracker.
+        assertNull(getAutoKiForBinder(testInfo.binder));
+    }
 }