Stop TCP keepalive from CS for fd initiated stop events

When a TCP keepalive start request is sent from ConnectivityService
to KeepaliveTracker, the keepalive is tracked on both
AutomaticOnOffKeepaliveTracker and KeepaliveTracker. The existing
design stops the keepalive inside KeepaliveTracker without
notifying AutomaticOnOffKeepaliveTracker for fd initiated stops.
This causes AutomaticOnOffKeepaliveTracker to lose track of these
fd initiated stops and cause a leak on the object.

The updated design sends the event to ConnectivityService handler
to handle the event down from ConnectivityService. This ensures
that each stakeholder class will get the stop event.

Bug: 283885097
Test: atest FrameworksNetTests
Test: atest CtsNetTestCases
Change-Id: I3c40d80694cd2c046f3a19ddb8f437878c98ab43
diff --git a/service/src/com/android/server/connectivity/KeepaliveTracker.java b/service/src/com/android/server/connectivity/KeepaliveTracker.java
index 941b616..1fd8a62 100644
--- a/service/src/com/android/server/connectivity/KeepaliveTracker.java
+++ b/service/src/com/android/server/connectivity/KeepaliveTracker.java
@@ -106,7 +106,12 @@
     private final int mAllowedUnprivilegedSlotsForUid;
 
     public KeepaliveTracker(Context context, Handler handler) {
-        mTcpController = new TcpKeepaliveController(handler);
+        this(context, handler, new TcpKeepaliveController(handler));
+    }
+
+    @VisibleForTesting
+    KeepaliveTracker(Context context, Handler handler, TcpKeepaliveController tcpController) {
+        mTcpController = tcpController;
         mContext = context;
 
         mSupportedKeepalives = KeepaliveResourceUtil.getSupportedKeepalives(context);
@@ -375,7 +380,7 @@
                         break;
                     case TYPE_TCP:
                         try {
-                            mTcpController.startSocketMonitor(mFd, this, mSlot);
+                            mTcpController.startSocketMonitor(mFd, mCallback, mSlot);
                         } catch (InvalidSocketException e) {
                             handleStopKeepalive(mNai, mSlot, ERROR_INVALID_SOCKET);
                             return ERROR_INVALID_SOCKET;
@@ -455,12 +460,6 @@
             }
         }
 
-        // TODO: This does not clean up the autoKi in AutomaticOnOffKeepaliveTracker and it is not
-        // possible without a big refactor.
-        void onFileDescriptorInitiatedStop(final int socketKeepaliveReason) {
-            handleStopKeepalive(mNai, mSlot, socketKeepaliveReason);
-        }
-
         /**
          * Construct a new KeepaliveInfo from existing KeepaliveInfo with a new fd.
          */
diff --git a/service/src/com/android/server/connectivity/TcpKeepaliveController.java b/service/src/com/android/server/connectivity/TcpKeepaliveController.java
index a9cb2fa..0fd8604 100644
--- a/service/src/com/android/server/connectivity/TcpKeepaliveController.java
+++ b/service/src/com/android/server/connectivity/TcpKeepaliveController.java
@@ -15,6 +15,7 @@
  */
 package com.android.server.connectivity;
 
+import static android.net.NetworkAgent.CMD_STOP_SOCKET_KEEPALIVE;
 import static android.net.SocketKeepalive.DATA_RECEIVED;
 import static android.net.SocketKeepalive.ERROR_INVALID_IP_ADDRESS;
 import static android.net.SocketKeepalive.ERROR_INVALID_SOCKET;
@@ -33,6 +34,7 @@
 import static com.android.net.module.util.NetworkStackConstants.IPV4_HEADER_MIN_LEN;
 
 import android.annotation.NonNull;
+import android.net.ISocketKeepaliveCallback;
 import android.net.InvalidPacketException;
 import android.net.NetworkUtils;
 import android.net.SocketKeepalive.InvalidSocketException;
@@ -50,7 +52,6 @@
 import com.android.internal.annotations.GuardedBy;
 import com.android.internal.annotations.VisibleForTesting;
 import com.android.net.module.util.IpUtils;
-import com.android.server.connectivity.KeepaliveTracker.KeepaliveInfo;
 
 import java.io.FileDescriptor;
 import java.net.InetAddress;
@@ -88,6 +89,8 @@
 
     private final MessageQueue mFdHandlerQueue;
 
+    private final Handler mConnectivityServiceHandler;
+
     private static final int FD_EVENTS = EVENT_INPUT | EVENT_ERROR;
 
     private static final int TCP_HEADER_LENGTH = 20;
@@ -115,6 +118,7 @@
 
     public TcpKeepaliveController(final Handler connectivityServiceHandler) {
         mFdHandlerQueue = connectivityServiceHandler.getLooper().getQueue();
+        mConnectivityServiceHandler = connectivityServiceHandler;
     }
 
     /** Build tcp keepalive packet. */
@@ -324,12 +328,13 @@
      * Start monitoring incoming packets.
      *
      * @param fd socket fd to monitor.
-     * @param ki a {@link KeepaliveInfo} that tracks information about a socket keepalive.
+     * @param callback a {@link ISocketKeepaliveCallback} that tracks information about a socket
+     *                 keepalive.
      * @param slot keepalive slot.
      */
-    public void startSocketMonitor(@NonNull final FileDescriptor fd,
-            @NonNull final KeepaliveInfo ki, final int slot)
-            throws IllegalArgumentException, InvalidSocketException {
+    public void startSocketMonitor(
+            @NonNull final FileDescriptor fd, @NonNull final ISocketKeepaliveCallback callback,
+            final int slot) throws IllegalArgumentException, InvalidSocketException {
         synchronized (mListeners) {
             if (null != mListeners.get(slot)) {
                 throw new IllegalArgumentException("This slot is already taken");
@@ -350,7 +355,8 @@
                 } else {
                     reason = DATA_RECEIVED;
                 }
-                ki.onFileDescriptorInitiatedStop(reason);
+                mConnectivityServiceHandler.obtainMessage(CMD_STOP_SOCKET_KEEPALIVE,
+                        0 /* unused */, reason, callback.asBinder()).sendToTarget();
                 // The listener returns the new set of events to listen to. Because 0 means no
                 // event, the listener gets unregistered.
                 return 0;
diff --git a/tests/unit/java/com/android/server/connectivity/AutomaticOnOffKeepaliveTrackerTest.java b/tests/unit/java/com/android/server/connectivity/AutomaticOnOffKeepaliveTrackerTest.java
index 608e6d8..db65c2b 100644
--- a/tests/unit/java/com/android/server/connectivity/AutomaticOnOffKeepaliveTrackerTest.java
+++ b/tests/unit/java/com/android/server/connectivity/AutomaticOnOffKeepaliveTrackerTest.java
@@ -18,6 +18,7 @@
 
 import static android.content.pm.PackageManager.PERMISSION_GRANTED;
 import static android.net.ConnectivityManager.TYPE_MOBILE;
+import static android.net.NetworkAgent.CMD_STOP_SOCKET_KEEPALIVE;
 import static android.net.NetworkCapabilities.TRANSPORT_CELLULAR;
 
 import static com.android.testutils.HandlerUtils.visibleOnHandlerThread;
@@ -28,6 +29,7 @@
 import static org.junit.Assert.assertNull;
 import static org.junit.Assert.assertThrows;
 import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.ArgumentMatchers.anyBoolean;
 import static org.mockito.ArgumentMatchers.anyInt;
@@ -47,6 +49,7 @@
 import android.content.res.Resources;
 import android.net.INetd;
 import android.net.ISocketKeepaliveCallback;
+import android.net.KeepalivePacketData;
 import android.net.LinkAddress;
 import android.net.LinkProperties;
 import android.net.MarkMaskParcel;
@@ -55,6 +58,7 @@
 import android.net.NetworkCapabilities;
 import android.net.NetworkInfo;
 import android.net.SocketKeepalive;
+import android.net.TcpKeepalivePacketData;
 import android.os.Binder;
 import android.os.Build;
 import android.os.Handler;
@@ -120,6 +124,7 @@
 
     TestKeepaliveTracker mKeepaliveTracker;
     AOOTestHandler mTestHandler;
+    TestTcpKeepaliveController mTcpController;
 
     // Hexadecimal representation of a SOCK_DIAG response with tcp info.
     private static final String SOCK_DIAG_TCP_INET_HEX =
@@ -233,9 +238,9 @@
         public final FileDescriptor fd;
         public final ISocketKeepaliveCallback socketKeepaliveCallback;
         public final Network underpinnedNetwork;
-        public final NattKeepalivePacketData kpd;
+        public final KeepalivePacketData kpd;
 
-        TestKeepaliveInfo(NattKeepalivePacketData kpd) throws Exception {
+        TestKeepaliveInfo(KeepalivePacketData kpd) throws Exception {
             this.kpd = kpd;
             socket = new Socket();
             socket.bind(null);
@@ -252,8 +257,9 @@
     private class TestKeepaliveTracker extends KeepaliveTracker {
         private KeepaliveInfo mKi;
 
-        TestKeepaliveTracker(@NonNull final Context context, @NonNull final Handler handler) {
-            super(context, handler);
+        TestKeepaliveTracker(@NonNull final Context context, @NonNull final Handler handler,
+                @NonNull final TcpKeepaliveController tcpController) {
+            super(context, handler, tcpController);
         }
 
         public void setReturnedKeepaliveInfo(@NonNull final KeepaliveInfo ki) {
@@ -272,6 +278,24 @@
             }
             return mKi;
         }
+
+        @NonNull
+        @Override
+        public KeepaliveInfo makeTcpKeepaliveInfo(@Nullable final NetworkAgentInfo nai,
+                @Nullable final FileDescriptor fd, final int intervalSeconds,
+                @NonNull final ISocketKeepaliveCallback cb) {
+            if (null == mKi) {
+                throw new IllegalStateException("Please call `setReturnedKeepaliveInfo`"
+                        + " before makeTcpKeepaliveInfo is called");
+            }
+            return mKi;
+        }
+    }
+
+    private static class TestTcpKeepaliveController extends TcpKeepaliveController {
+        TestTcpKeepaliveController(final Handler connectivityServiceHandler) {
+            super(connectivityServiceHandler);
+        }
     }
 
     @Before
@@ -303,7 +327,8 @@
         mHandlerThread = new HandlerThread("KeepaliveTrackerTest");
         mHandlerThread.start();
         mTestHandler = new AOOTestHandler(mHandlerThread.getLooper());
-        mKeepaliveTracker = new TestKeepaliveTracker(mCtx, mTestHandler);
+        mTcpController = new TestTcpKeepaliveController(mTestHandler);
+        mKeepaliveTracker = new TestKeepaliveTracker(mCtx, mTestHandler, mTcpController);
         doReturn(mKeepaliveTracker).when(mDependencies).newKeepaliveTracker(mCtx, mTestHandler);
         doReturn(true).when(mDependencies).isFeatureEnabled(any(), anyBoolean());
         mAOOKeepaliveTracker =
@@ -333,6 +358,14 @@
                     Log.d(TAG, "Test handler received CMD_MONITOR_AUTOMATIC_KEEPALIVE : " + msg);
                     mLastAutoKi = mAOOKeepaliveTracker.getKeepaliveForBinder((IBinder) msg.obj);
                     break;
+                case CMD_STOP_SOCKET_KEEPALIVE:
+                    Log.d(TAG, "Test handler received CMD_STOP_SOCKET_KEEPALIVE : " + msg);
+                    mLastAutoKi = mAOOKeepaliveTracker.getKeepaliveForBinder((IBinder) msg.obj);
+                    if (mLastAutoKi == null) {
+                        fail("Attempt to stop an already stopped keepalive");
+                    }
+                    mAOOKeepaliveTracker.handleStopKeepalive(mLastAutoKi, msg.arg2);
+                    break;
             }
         }
     }
@@ -481,14 +514,15 @@
                 mTestHandler, () -> mAOOKeepaliveTracker.getKeepaliveForBinder(binder));
     }
 
-    private void checkAndProcessKeepaliveStart(final NattKeepalivePacketData kpd) throws Exception {
+    private void checkAndProcessKeepaliveStart(final KeepalivePacketData kpd) throws Exception {
         checkAndProcessKeepaliveStart(TEST_SLOT, kpd);
     }
 
     private void checkAndProcessKeepaliveStart(
-            int slot, final NattKeepalivePacketData kpd) throws Exception {
-        verify(mNai).onStartNattSocketKeepalive(slot, TEST_KEEPALIVE_INTERVAL_SEC, kpd);
-        verify(mNai).onAddNattKeepalivePacketFilter(slot, kpd);
+            int slot, final KeepalivePacketData kpd) throws Exception {
+        verify(mNai).onStartNattSocketKeepalive(
+                slot, TEST_KEEPALIVE_INTERVAL_SEC, (NattKeepalivePacketData) kpd);
+        verify(mNai).onAddNattKeepalivePacketFilter(slot, (NattKeepalivePacketData) kpd);
         triggerEventKeepalive(slot, SocketKeepalive.SUCCESS);
     }
 
@@ -531,9 +565,10 @@
     public void testHandleEventSocketKeepalive_startingFailureHardwareError() throws Exception {
         final TestKeepaliveInfo testInfo = doStartNattKeepalive();
 
-        verify(mNai)
-                .onStartNattSocketKeepalive(TEST_SLOT, TEST_KEEPALIVE_INTERVAL_SEC, testInfo.kpd);
-        verify(mNai).onAddNattKeepalivePacketFilter(TEST_SLOT, testInfo.kpd);
+        verify(mNai).onStartNattSocketKeepalive(
+                TEST_SLOT, TEST_KEEPALIVE_INTERVAL_SEC, (NattKeepalivePacketData) testInfo.kpd);
+        verify(mNai).onAddNattKeepalivePacketFilter(
+                TEST_SLOT, (NattKeepalivePacketData) testInfo.kpd);
         // Network agent returns an error, fails to start the keepalive.
         triggerEventKeepalive(TEST_SLOT, SocketKeepalive.ERROR_HARDWARE_ERROR);
 
@@ -674,9 +709,10 @@
         clearInvocations(mNai);
         doResumeKeepalive(getAutoKiForBinder(testInfo.binder));
 
-        verify(mNai)
-                .onStartNattSocketKeepalive(TEST_SLOT, TEST_KEEPALIVE_INTERVAL_SEC, testInfo.kpd);
-        verify(mNai).onAddNattKeepalivePacketFilter(TEST_SLOT, testInfo.kpd);
+        verify(mNai).onStartNattSocketKeepalive(
+                TEST_SLOT, TEST_KEEPALIVE_INTERVAL_SEC, (NattKeepalivePacketData) testInfo.kpd);
+        verify(mNai).onAddNattKeepalivePacketFilter(
+                TEST_SLOT, (NattKeepalivePacketData) testInfo.kpd);
         // Network agent returns error on starting the keepalive.
         triggerEventKeepalive(TEST_SLOT, SocketKeepalive.ERROR_HARDWARE_ERROR);
 
@@ -772,4 +808,38 @@
         verifyNoMoreInteractions(ignoreStubs(testInfo1.socketKeepaliveCallback));
         verifyNoMoreInteractions(ignoreStubs(testInfo2.socketKeepaliveCallback));
     }
+
+    @Test
+    public void testStartTcpKeepalive_fdInitiatedStop() throws Exception {
+        final InetAddress srcAddress = InetAddress.getByAddress(
+                new byte[] { (byte) 192, 0, 0, (byte) 129 });
+        mNai.linkProperties.addLinkAddress(new LinkAddress(srcAddress, 24));
+
+        final KeepalivePacketData kpd = new TcpKeepalivePacketData(
+                InetAddress.getByAddress(new byte[] { (byte) 192, 0, 0, (byte) 129 }) /* srcAddr */,
+                12345 /* srcPort */,
+                InetAddress.getByAddress(new byte[] { 8, 8, 8, 8}) /* dstAddr */,
+                12345 /* dstPort */, new byte[] {1},  111 /* tcpSeq */,
+                222 /* tcpAck */, 800 /* tcpWindow */, 2 /* tcpWindowScale */,
+                4 /* ipTos */, 64 /* ipTtl */);
+        final TestKeepaliveInfo testInfo = new TestKeepaliveInfo(kpd);
+
+        final KeepaliveInfo ki = mKeepaliveTracker.new KeepaliveInfo(
+                testInfo.socketKeepaliveCallback, mNai, kpd,
+                TEST_KEEPALIVE_INTERVAL_SEC, KeepaliveInfo.TYPE_TCP, testInfo.fd);
+        mKeepaliveTracker.setReturnedKeepaliveInfo(ki);
+
+        // Setup TCP keepalive.
+        mAOOKeepaliveTracker.startTcpKeepalive(mNai, testInfo.fd, TEST_KEEPALIVE_INTERVAL_SEC,
+                testInfo.socketKeepaliveCallback);
+        HandlerUtils.waitForIdle(mTestHandler, TIMEOUT_MS);
+
+        // A closed socket will result in EVENT_HANGUP and trigger error to
+        // FileDescriptorEventListener.
+        testInfo.socket.close();
+        HandlerUtils.waitForIdle(mTestHandler, TIMEOUT_MS);
+
+        // The keepalive should be removed in AutomaticOnOffKeepaliveTracker.
+        getAutoKiForBinder(testInfo.binder);
+    }
 }