Fix a bug where alarms are deduplicated

Test: new test in this patch
Bug: 269719647
Change-Id: I32e8d35f1751e2fd81a0007ae6cb308b6bcbb463
diff --git a/service/src/com/android/server/connectivity/AutomaticOnOffKeepaliveTracker.java b/service/src/com/android/server/connectivity/AutomaticOnOffKeepaliveTracker.java
index 76942ff..9f9b496 100644
--- a/service/src/com/android/server/connectivity/AutomaticOnOffKeepaliveTracker.java
+++ b/service/src/com/android/server/connectivity/AutomaticOnOffKeepaliveTracker.java
@@ -16,8 +16,6 @@
 
 package com.android.server.connectivity;
 
-import static android.Manifest.permission.NETWORK_STACK;
-import static android.content.Context.RECEIVER_NOT_EXPORTED;
 import static android.net.NetworkAgent.CMD_START_SOCKET_KEEPALIVE;
 import static android.net.SocketKeepalive.ERROR_INVALID_SOCKET;
 import static android.net.SocketKeepalive.SUCCESS_PAUSED;
@@ -36,18 +34,13 @@
 import android.annotation.NonNull;
 import android.annotation.Nullable;
 import android.app.AlarmManager;
-import android.app.PendingIntent;
-import android.content.BroadcastReceiver;
 import android.content.Context;
-import android.content.Intent;
-import android.content.IntentFilter;
 import android.net.INetd;
 import android.net.ISocketKeepaliveCallback;
 import android.net.MarkMaskParcel;
 import android.net.Network;
 import android.net.NetworkAgent;
 import android.net.SocketKeepalive.InvalidSocketException;
-import android.os.Bundle;
 import android.os.FileUtils;
 import android.os.Handler;
 import android.os.IBinder;
@@ -63,7 +56,6 @@
 
 import com.android.internal.annotations.VisibleForTesting;
 import com.android.internal.util.IndentingPrintWriter;
-import com.android.modules.utils.build.SdkLevel;
 import com.android.net.module.util.BinderUtils;
 import com.android.net.module.util.CollectionUtils;
 import com.android.net.module.util.DeviceConfigUtils;
@@ -96,8 +88,6 @@
 public class AutomaticOnOffKeepaliveTracker {
     private static final String TAG = "AutomaticOnOffKeepaliveTracker";
     private static final int[] ADDRESS_FAMILIES = new int[] {AF_INET6, AF_INET};
-    private static final String ACTION_TCP_POLLING_ALARM =
-            "com.android.server.connectivity.KeepaliveTracker.TCP_POLLING_ALARM";
     private static final String EXTRA_BINDER_TOKEN = "token";
     private static final long DEFAULT_TCP_POLLING_INTERVAL_MS = 120_000L;
     private static final long LOW_TCP_POLLING_INTERVAL_MS = 1_000L;
@@ -162,19 +152,6 @@
     // TODO: Remove this when TCP polling design is replaced with callback.
     private long mTestLowTcpPollingTimerUntilMs = 0;
 
-    private final BroadcastReceiver mReceiver = new BroadcastReceiver() {
-        @Override
-        public void onReceive(Context context, Intent intent) {
-            if (ACTION_TCP_POLLING_ALARM.equals(intent.getAction())) {
-                Log.d(TAG, "Received TCP polling intent");
-                final IBinder token = intent.getBundleExtra(EXTRA_BINDER_TOKEN).getBinder(
-                        EXTRA_BINDER_TOKEN);
-                mConnectivityServiceHandler.obtainMessage(
-                        NetworkAgent.CMD_MONITOR_AUTOMATIC_KEEPALIVE, token).sendToTarget();
-            }
-        }
-    };
-
     /**
      * Information about a managed keepalive.
      *
@@ -195,7 +172,7 @@
         @Nullable
         private final FileDescriptor mFd;
         @Nullable
-        private final PendingIntent mTcpPollingAlarm;
+        private final AlarmManager.OnAlarmListener mAlarmListener;
         @AutomaticOnOffState
         private int mAutomaticOnOffState;
         @Nullable
@@ -221,17 +198,24 @@
                     Log.e(TAG, "Cannot dup fd: ", e);
                     throw new InvalidSocketException(ERROR_INVALID_SOCKET, e);
                 }
-                mTcpPollingAlarm = createTcpPollingAlarmIntent(context, mCallback.asBinder());
+                mAlarmListener = () -> mConnectivityServiceHandler.obtainMessage(
+                        NetworkAgent.CMD_MONITOR_AUTOMATIC_KEEPALIVE, mCallback.asBinder())
+                        .sendToTarget();
             } else {
                 mAutomaticOnOffState = STATE_ALWAYS_ON;
                 // A null fd is acceptable in KeepaliveInfo for backward compatibility of
                 // PacketKeepalive API, but it must never happen with automatic keepalives.
                 // TODO : remove mFd from KeepaliveInfo or from this class.
                 mFd = ki.mFd;
-                mTcpPollingAlarm = null;
+                mAlarmListener = null;
             }
         }
 
+        @VisibleForTesting
+        public ISocketKeepaliveCallback getCallback() {
+            return mCallback;
+        }
+
         public Network getNetwork() {
             return mKi.getNai().network;
         }
@@ -245,18 +229,6 @@
             return mKi.getNai().network().equals(network) && mKi.getSlot() == slot;
         }
 
-        private PendingIntent createTcpPollingAlarmIntent(@NonNull Context context,
-                @NonNull IBinder token) {
-            final Intent intent = new Intent(ACTION_TCP_POLLING_ALARM);
-            intent.setPackage(context.getPackageName());
-            // Intent doesn't expose methods to put extra Binders, but Bundle does.
-            final Bundle b = new Bundle();
-            b.putBinder(EXTRA_BINDER_TOKEN, token);
-            intent.putExtra(EXTRA_BINDER_TOKEN, b);
-            return BinderUtils.withCleanCallingIdentity(() -> PendingIntent.getBroadcast(
-                    context, 0 /* requestCode */, intent, PendingIntent.FLAG_IMMUTABLE));
-        }
-
         @Override
         public void binderDied() {
             mConnectivityServiceHandler.post(() -> cleanupAutoOnOffKeepalive(this));
@@ -306,18 +278,15 @@
         mKeepaliveTracker = mDependencies.newKeepaliveTracker(
                 mContext, mConnectivityServiceHandler);
 
-        if (SdkLevel.isAtLeastU()) {
-            mContext.registerReceiver(mReceiver, new IntentFilter(ACTION_TCP_POLLING_ALARM),
-                    NETWORK_STACK, handler, RECEIVER_NOT_EXPORTED);
-        }
-        mAlarmManager = mContext.getSystemService(AlarmManager.class);
+        mAlarmManager = mDependencies.getAlarmManager(context);
     }
 
-    private void startTcpPollingAlarm(@NonNull PendingIntent alarm) {
+    private void startTcpPollingAlarm(@NonNull final AlarmManager.OnAlarmListener listener) {
         final long triggerAtMillis =
                 SystemClock.elapsedRealtime() + getTcpPollingInterval();
         // Setup a non-wake up alarm.
-        mAlarmManager.setExact(AlarmManager.ELAPSED_REALTIME, triggerAtMillis, alarm);
+        mAlarmManager.setExact(AlarmManager.ELAPSED_REALTIME, triggerAtMillis, null /* tag */,
+                listener, mConnectivityServiceHandler);
     }
 
     /**
@@ -354,7 +323,7 @@
             handleMaybeResumeKeepalive(ki);
         }
         // TODO: listen to socket status instead of periodically check.
-        startTcpPollingAlarm(ki.mTcpPollingAlarm);
+        startTcpPollingAlarm(ki.mAlarmListener);
     }
 
     /**
@@ -434,7 +403,7 @@
         }
         mAutomaticOnOffKeepalives.add(autoKi);
         if (STATE_ALWAYS_ON != autoKi.mAutomaticOnOffState) {
-            startTcpPollingAlarm(autoKi.mTcpPollingAlarm);
+            startTcpPollingAlarm(autoKi.mAlarmListener);
         }
     }
 
@@ -466,7 +435,7 @@
     private void cleanupAutoOnOffKeepalive(@NonNull final AutomaticOnOffKeepalive autoKi) {
         ensureRunningOnHandlerThread();
         autoKi.close();
-        if (null != autoKi.mTcpPollingAlarm) mAlarmManager.cancel(autoKi.mTcpPollingAlarm);
+        if (null != autoKi.mAlarmListener) mAlarmManager.cancel(autoKi.mAlarmListener);
 
         // If the KI is not in the array, it's because it was already removed, or it was never
         // added ; the only ways this can happen is if the keepalive is stopped by the app and the
@@ -772,6 +741,13 @@
         }
 
         /**
+         * Get an instance of AlarmManager
+         */
+        public AlarmManager getAlarmManager(@NonNull final Context ctx) {
+            return ctx.getSystemService(AlarmManager.class);
+        }
+
+        /**
          * Receive the response message from kernel via given {@code FileDescriptor}.
          * The usage should follow the {@code #sendRequest} call with the same
          * FileDescriptor.
diff --git a/service/src/com/android/server/connectivity/KeepaliveTracker.java b/service/src/com/android/server/connectivity/KeepaliveTracker.java
index 5572361..06294db 100644
--- a/service/src/com/android/server/connectivity/KeepaliveTracker.java
+++ b/service/src/com/android/server/connectivity/KeepaliveTracker.java
@@ -58,6 +58,7 @@
 import android.util.Pair;
 
 import com.android.connectivity.resources.R;
+import com.android.internal.annotations.VisibleForTesting;
 import com.android.internal.util.IndentingPrintWriter;
 import com.android.net.module.util.HexDump;
 import com.android.net.module.util.IpUtils;
@@ -125,7 +126,8 @@
      * All information about this keepalive is known at construction time except the slot number,
      * which is only returned when the hardware has successfully started the keepalive.
      */
-    class KeepaliveInfo implements IBinder.DeathRecipient {
+    @VisibleForTesting
+    public class KeepaliveInfo implements IBinder.DeathRecipient {
         // TODO : remove this member. Only AutoOnOffKeepalive should have a reference to this.
         public final ISocketKeepaliveCallback mCallback;
         // Bookkeeping data.
diff --git a/tests/unit/java/com/android/server/connectivity/AutomaticOnOffKeepaliveTrackerTest.java b/tests/unit/java/com/android/server/connectivity/AutomaticOnOffKeepaliveTrackerTest.java
index ddf1d4d..4f0b9c4 100644
--- a/tests/unit/java/com/android/server/connectivity/AutomaticOnOffKeepaliveTrackerTest.java
+++ b/tests/unit/java/com/android/server/connectivity/AutomaticOnOffKeepaliveTrackerTest.java
@@ -16,32 +16,70 @@
 
 package com.android.server.connectivity;
 
+import static android.content.pm.PackageManager.PERMISSION_GRANTED;
+import static android.net.ConnectivityManager.TYPE_MOBILE;
+import static android.net.NetworkCapabilities.TRANSPORT_CELLULAR;
+
+import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertNotNull;
 import static org.junit.Assert.assertThrows;
 import static org.junit.Assert.assertTrue;
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.ArgumentMatchers.anyBoolean;
+import static org.mockito.ArgumentMatchers.anyInt;
+import static org.mockito.ArgumentMatchers.anyLong;
+import static org.mockito.ArgumentMatchers.eq;
 import static org.mockito.Mockito.doNothing;
 import static org.mockito.Mockito.doReturn;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.verify;
 
+import android.app.AlarmManager;
 import android.content.Context;
+import android.content.res.Resources;
+import android.net.ConnectivityResources;
 import android.net.INetd;
+import android.net.ISocketKeepaliveCallback;
+import android.net.KeepalivePacketData;
+import android.net.LinkAddress;
+import android.net.LinkProperties;
 import android.net.MarkMaskParcel;
+import android.net.NattKeepalivePacketData;
+import android.net.Network;
+import android.net.NetworkAgent;
+import android.net.NetworkCapabilities;
+import android.net.NetworkInfo;
+import android.os.Binder;
 import android.os.Build;
+import android.os.Handler;
 import android.os.HandlerThread;
+import android.os.IBinder;
+import android.os.Looper;
+import android.os.Message;
 import android.test.suitebuilder.annotation.SmallTest;
+import android.util.Log;
 
+import androidx.annotation.NonNull;
+import androidx.annotation.Nullable;
+
+import com.android.server.connectivity.KeepaliveTracker.KeepaliveInfo;
 import com.android.testutils.DevSdkIgnoreRule;
 import com.android.testutils.DevSdkIgnoreRunner;
+import com.android.testutils.HandlerUtils;
 
 import libcore.util.HexEncoding;
 
 import org.junit.Before;
 import org.junit.Test;
 import org.junit.runner.RunWith;
+import org.mockito.ArgumentCaptor;
 import org.mockito.Mock;
 import org.mockito.MockitoAnnotations;
 
+import java.io.FileDescriptor;
+import java.net.InetAddress;
+import java.net.Socket;
 import java.nio.ByteBuffer;
 import java.nio.ByteOrder;
 
@@ -49,17 +87,22 @@
 @SmallTest
 @DevSdkIgnoreRule.IgnoreUpTo(Build.VERSION_CODES.TIRAMISU)
 public class AutomaticOnOffKeepaliveTrackerTest {
+    private static final String TAG = AutomaticOnOffKeepaliveTrackerTest.class.getSimpleName();
     private static final int TEST_NETID = 0xA85;
     private static final int TEST_NETID_FWMARK = 0x0A85;
     private static final int OTHER_NETID = 0x1A85;
     private static final int NETID_MASK = 0xffff;
+    private static final int TIMEOUT_MS = 30_000;
+    private static final int MOCK_RESOURCE_ID = 5;
     private AutomaticOnOffKeepaliveTracker mAOOKeepaliveTracker;
     private HandlerThread mHandlerThread;
 
     @Mock INetd mNetd;
     @Mock AutomaticOnOffKeepaliveTracker.Dependencies mDependencies;
     @Mock Context mCtx;
-    @Mock KeepaliveTracker mKeepaliveTracker;
+    @Mock AlarmManager mAlarmManager;
+    TestKeepaliveTracker mKeepaliveTracker;
+    AOOTestHandler mTestHandler;
 
     // Hexadecimal representation of a SOCK_DIAG response with tcp info.
     private static final String SOCK_DIAG_TCP_INET_HEX =
@@ -158,11 +201,46 @@
     private static final byte[] TEST_RESPONSE_BYTES =
             HexEncoding.decode(TEST_RESPONSE_HEX.toCharArray(), false);
 
+    private class TestKeepaliveTracker extends KeepaliveTracker {
+        private KeepaliveInfo mKi;
+
+        TestKeepaliveTracker(@NonNull final Context context, @NonNull final Handler handler) {
+            super(context, handler);
+        }
+
+        public void setReturnedKeepaliveInfo(@NonNull final KeepaliveInfo ki) {
+            mKi = ki;
+        }
+
+        @NonNull
+        @Override
+        public KeepaliveInfo makeNattKeepaliveInfo(@Nullable final NetworkAgentInfo nai,
+                @Nullable final FileDescriptor fd, final int intervalSeconds,
+                @NonNull final ISocketKeepaliveCallback cb, @NonNull final String srcAddrString,
+                final int srcPort,
+                @NonNull final String dstAddrString, final int dstPort) {
+            if (null == mKi) {
+                throw new IllegalStateException("Must call setReturnedKeepaliveInfo");
+            }
+            return mKi;
+        }
+    }
+
     @Before
     public void setup() throws Exception {
         MockitoAnnotations.initMocks(this);
 
+        doReturn(PERMISSION_GRANTED).when(mCtx).checkPermission(any() /* permission */,
+                anyInt() /* pid */, anyInt() /* uid */);
+        ConnectivityResources.setResourcesContextForTest(mCtx);
+        final Resources mockResources = mock(Resources.class);
+        doReturn(MOCK_RESOURCE_ID).when(mockResources).getIdentifier(any() /* name */,
+                any() /* defType */, any() /* defPackage */);
+        doReturn(new String[] { "0,3", "3,3" }).when(mockResources)
+                .getStringArray(MOCK_RESOURCE_ID);
+        doReturn(mockResources).when(mCtx).getResources();
         doReturn(mNetd).when(mDependencies).getNetd();
+        doReturn(mAlarmManager).when(mDependencies).getAlarmManager(any());
         doReturn(makeMarkMaskParcel(NETID_MASK, TEST_NETID_FWMARK)).when(mNetd)
                 .getFwmarkForNetwork(TEST_NETID);
 
@@ -170,11 +248,34 @@
 
         mHandlerThread = new HandlerThread("KeepaliveTrackerTest");
         mHandlerThread.start();
-        doReturn(mKeepaliveTracker).when(mDependencies).newKeepaliveTracker(
-                mCtx, mHandlerThread.getThreadHandler());
+        mTestHandler = new AOOTestHandler(mHandlerThread.getLooper());
+        mKeepaliveTracker = new TestKeepaliveTracker(mCtx, mTestHandler);
+        doReturn(mKeepaliveTracker).when(mDependencies).newKeepaliveTracker(mCtx, mTestHandler);
         doReturn(true).when(mDependencies).isFeatureEnabled(any(), anyBoolean());
-        mAOOKeepaliveTracker = new AutomaticOnOffKeepaliveTracker(
-                mCtx, mHandlerThread.getThreadHandler(), mDependencies);
+        mAOOKeepaliveTracker =
+                new AutomaticOnOffKeepaliveTracker(mCtx, mTestHandler, mDependencies);
+    }
+
+    private final class AOOTestHandler extends Handler {
+        public AutomaticOnOffKeepaliveTracker.AutomaticOnOffKeepalive mLastAutoKi = null;
+
+        AOOTestHandler(@NonNull final Looper looper) {
+            super(looper);
+        }
+
+        @Override
+        public void handleMessage(@NonNull final Message msg) {
+            switch (msg.what) {
+                case NetworkAgent.CMD_START_SOCKET_KEEPALIVE:
+                    Log.d(TAG, "Test handler received CMD_START_SOCKET_KEEPALIVE : " + msg);
+                    mAOOKeepaliveTracker.handleStartKeepalive(msg);
+                    break;
+                case NetworkAgent.CMD_MONITOR_AUTOMATIC_KEEPALIVE:
+                    Log.d(TAG, "Test handler received CMD_MONITOR_AUTOMATIC_KEEPALIVE : " + msg);
+                    mLastAutoKi = mAOOKeepaliveTracker.getKeepaliveForBinder((IBinder) msg.obj);
+                    break;
+            }
+        }
     }
 
     @Test
@@ -187,24 +288,79 @@
     @Test
     public void testIsAnyTcpSocketConnected_withTargetNetId() throws Exception {
         setupResponseWithSocketExisting();
-        mHandlerThread.getThreadHandler().post(
+        mTestHandler.post(
                 () -> assertTrue(mAOOKeepaliveTracker.isAnyTcpSocketConnected(TEST_NETID)));
     }
 
     @Test
     public void testIsAnyTcpSocketConnected_withIncorrectNetId() throws Exception {
         setupResponseWithSocketExisting();
-        mHandlerThread.getThreadHandler().post(
+        mTestHandler.post(
                 () -> assertFalse(mAOOKeepaliveTracker.isAnyTcpSocketConnected(OTHER_NETID)));
     }
 
     @Test
     public void testIsAnyTcpSocketConnected_noSocketExists() throws Exception {
         setupResponseWithoutSocketExisting();
-        mHandlerThread.getThreadHandler().post(
+        mTestHandler.post(
                 () -> assertFalse(mAOOKeepaliveTracker.isAnyTcpSocketConnected(TEST_NETID)));
     }
 
+    @Test
+    public void testAlarm() throws Exception {
+        final InetAddress srcAddress = InetAddress.getByAddress(
+                new byte[] { (byte) 192, 0, 0, (byte) 129 });
+        final int srcPort = 12345;
+        final InetAddress dstAddress = InetAddress.getByAddress(new byte[] { 8, 8, 8, 8});
+        final int dstPort = 12345;
+
+        final NetworkAgentInfo nai = mock(NetworkAgentInfo.class);
+        nai.networkCapabilities = new NetworkCapabilities.Builder()
+                .addTransportType(TRANSPORT_CELLULAR).build();
+        nai.networkInfo = new NetworkInfo(TYPE_MOBILE, 0 /* subtype */, "LTE", "LTE");
+        nai.networkInfo.setDetailedState(NetworkInfo.DetailedState.CONNECTED, "test reason",
+                "test extra info");
+        nai.linkProperties = new LinkProperties();
+        nai.linkProperties.addLinkAddress(new LinkAddress(srcAddress, 24));
+
+        final Socket socket = new Socket();
+        socket.bind(null);
+        final FileDescriptor fd = socket.getFileDescriptor$();
+        final IBinder binder = new Binder();
+        final ISocketKeepaliveCallback cb = mock(ISocketKeepaliveCallback.class);
+        doReturn(binder).when(cb).asBinder();
+        final Network underpinnedNetwork = mock(Network.class);
+
+        final KeepalivePacketData kpd = new NattKeepalivePacketData(srcAddress, srcPort,
+                dstAddress, dstPort, new byte[] {1});
+        final KeepaliveInfo ki = mKeepaliveTracker.new KeepaliveInfo(cb, nai, kpd,
+                10 /* interval */, KeepaliveInfo.TYPE_NATT, fd);
+        mKeepaliveTracker.setReturnedKeepaliveInfo(ki);
+
+        mAOOKeepaliveTracker.startNattKeepalive(nai, fd, 10 /* intervalSeconds */, cb,
+                srcAddress.toString(), srcPort, dstAddress.toString(), dstPort,
+                true /* automaticOnOffKeepalives */, underpinnedNetwork);
+        HandlerUtils.waitForIdle(mTestHandler, TIMEOUT_MS);
+
+        final ArgumentCaptor<AlarmManager.OnAlarmListener> listenerCaptor =
+                ArgumentCaptor.forClass(AlarmManager.OnAlarmListener.class);
+        verify(mAlarmManager).setExact(eq(AlarmManager.ELAPSED_REALTIME), anyLong(),
+                any(), listenerCaptor.capture(), eq(mTestHandler));
+        final AlarmManager.OnAlarmListener listener = listenerCaptor.getValue();
+
+        // For realism, the listener should be posted on the handler
+        mTestHandler.post(() -> listener.onAlarm());
+        // Wait for the listener to be called. The listener enqueues a message to the handler.
+        HandlerUtils.waitForIdle(mTestHandler, TIMEOUT_MS);
+        // Wait for the message posted by the listener to be processed.
+        HandlerUtils.waitForIdle(mTestHandler, TIMEOUT_MS);
+
+        assertNotNull(mTestHandler.mLastAutoKi);
+        assertEquals(cb, mTestHandler.mLastAutoKi.getCallback());
+        assertEquals(underpinnedNetwork, mTestHandler.mLastAutoKi.getUnderpinnedNetwork());
+        socket.close();
+    }
+
     private void setupResponseWithSocketExisting() throws Exception {
         final ByteBuffer tcpBufferV6 = getByteBuffer(TEST_RESPONSE_BYTES);
         final ByteBuffer tcpBufferV4 = getByteBuffer(TEST_RESPONSE_BYTES);