Merge "[Autofill Framework] Fix authentication id by limiting request id to only 15 bits." into main
diff --git a/services/autofill/java/com/android/server/autofill/RequestId.java b/services/autofill/java/com/android/server/autofill/RequestId.java
index 29ad786..d8069a8 100644
--- a/services/autofill/java/com/android/server/autofill/RequestId.java
+++ b/services/autofill/java/com/android/server/autofill/RequestId.java
@@ -16,8 +16,13 @@
 
 package com.android.server.autofill;
 
-import java.util.List;
+import static com.android.server.autofill.Helper.sDebug;
+
+import android.util.Slog;
+
 import java.util.concurrent.atomic.AtomicInteger;
+import java.util.List;
+import java.util.Random;
 
 // Helper class containing various methods to deal with FillRequest Ids.
 // For authentication flows, there needs to be a way to know whether to retrieve the Fill
@@ -25,56 +30,97 @@
 // way to achieve this is by assigning odd number request ids to secondary provider and
 // even numbers to primary provider.
 public class RequestId {
+    private final AtomicInteger sIdCounter;
 
-  private AtomicInteger sIdCounter;
+    // The minimum request id is 2 to avoid possible authentication issues.
+    static final int MIN_REQUEST_ID = 2;
+    // The maximum request id is 0x7FFF to make sure the 16th bit is 0.
+    // This is to make sure the authentication id is always positive.
+    static final int MAX_REQUEST_ID = 0x7FFF; // 32767
 
-  // Mainly used for tests
-  RequestId(int start) {
-    sIdCounter = new AtomicInteger(start);
-  }
+    // The maximum start id is made small to best avoid wrapping around.
+    static final int MAX_START_ID = 1000;
+    // The magic number is used to determine if a wrap has happened.
+    // The underlying assumption of MAGIC_NUMBER is that one session never issues as many as
+    // MAGIC_NUMBER fill requests, so a gap larger than MAGIC_NUMBER between two consecutive
+    // (sorted) request ids can only be caused by the counter wrapping around.
+    static final int MAGIC_NUMBER = 5000;
 
-  public RequestId() {
-    this((int) (Math.floor(Math.random() * 0xFFFF)));
-  }
+    static final int MIN_PRIMARY_REQUEST_ID = 2;
+    static final int MAX_PRIMARY_REQUEST_ID = 0x7FFE; // 32766
 
-  public static int getLastRequestIdIndex(List<Integer> requestIds) {
-    int lastId = -1;
-    int indexOfBiggest = -1;
-    // Biggest number is usually the latest request, since IDs only increase
-    // The only exception is when the request ID wraps around back to 0
-      for (int i = requestIds.size() - 1; i >= 0; i--) {
-        if (requestIds.get(i) > lastId) {
-        lastId = requestIds.get(i);
-        indexOfBiggest = i;
-      }
-    }
+    static final int MIN_SECONDARY_REQUEST_ID = 3;
+    static final int MAX_SECONDARY_REQUEST_ID = 0x7FFF; // 32767
 
-    // 0xFFFE + 2 == 0x1 (for secondary)
-    // 0xFFFD + 2 == 0x0 (for primary)
-    // Wrap has occurred
-    if (lastId >= 0xFFFD) {
-      // Calculate the biggest size possible
-      // If list only has one kind of request ids - we need to multiple by 2
-      // (since they skip odd ints)
-      // Also subtract one from size because at least one integer exists pre-wrap
-      int calcSize = (requestIds.size()) * 2;
-      //Biggest possible id after wrapping
-      int biggestPossible = (lastId + calcSize) % 0xFFFF;
-      lastId = -1;
-      indexOfBiggest = -1;
-      for (int i = 0; i < requestIds.size(); i++) {
-        int currentId = requestIds.get(i);
-        if (currentId <= biggestPossible && currentId > lastId) {
-          lastId = currentId;
-          indexOfBiggest = i;
+    private static final String TAG = "RequestId";
+
+    // WARNING: This constructor should only be used for testing
+    RequestId(int startId) {
+        if (startId < MIN_REQUEST_ID || startId > MAX_REQUEST_ID) {
+            throw new IllegalArgumentException("startId must be between " + MIN_REQUEST_ID +
+                                                   " and " + MAX_REQUEST_ID);
         }
-      }
+        if (sDebug) {
+            Slog.d(TAG, "RequestId(int): startId= " + startId);
+        }
+        sIdCounter = new AtomicInteger(startId);
     }
 
-    return indexOfBiggest;
-  }
+    // WARNING: This get method should only be used for testing
+    int getRequestId() {
+        return sIdCounter.get();
+    }
 
-  public int nextId(boolean isSecondary) {
+    public RequestId() {
+        Random random = new Random();
+        int low = MIN_REQUEST_ID;
+        int high = MAX_START_ID + 1; // nextInt is exclusive on upper limit
+
+        // Generate a random start request id that is >= MIN_REQUEST_ID and <= MAX_START_ID
+        int startId = random.nextInt(high - low) + low;
+        if (sDebug) {
+            Slog.d(TAG, "RequestId(): startId= " + startId);
+        }
+        sIdCounter = new AtomicInteger(startId);
+    }
+
+    // Given a list of request ids, find the index of the last request id.
+    // Note: Since the request id wraps around, the largest request id may not be
+    // the latest request id.
+    //
+    // @param requestIds List of request ids sorted in ascending order.
+    // @return Index of the latest request id, or -1 if requestIds is empty.
+    public static int getLastRequestIdIndex(List<Integer> requestIds) {
+        // If there is only one request id, return index as 0.
+        if (requestIds.size() == 1) {
+            return 0;
+        }
+
+        // We have to use a magical number to determine if a wrap has happened because
+        // the request id could be lost. The underlying assumption of MAGIC_NUMBER is that
+        // there can't be as many as MAGIC_NUMBER of fill requests in one session.
+        boolean wrapHasHappened = false;
+        int latestRequestIdIndex = -1;
+
+        for (int i = 0; i < requestIds.size() - 1; i++) {
+            if (requestIds.get(i + 1) - requestIds.get(i) > MAGIC_NUMBER) {
+                wrapHasHappened = true;
+                latestRequestIdIndex = i;
+                break;
+            }
+        }
+
+        // If there was no wrap, the last request index is the last index.
+        if (!wrapHasHappened) {
+            latestRequestIdIndex = requestIds.size() - 1;
+        }
+        if (sDebug) {
+            Slog.d(TAG, "getLastRequestIdIndex(): latestRequestIdIndex = " + latestRequestIdIndex);
+        }
+        return latestRequestIdIndex;
+    }
+
+    public int nextId(boolean isSecondary) {
         // For authentication flows, there needs to be a way to know whether to retrieve the Fill
         // Response from the primary provider or the secondary provider from the requestId. A simple
         // way to achieve this is by assigning odd number request ids to secondary provider and
@@ -82,13 +128,20 @@
         int requestId;
 
         do {
-            requestId = sIdCounter.incrementAndGet() % 0xFFFF;
+            requestId = sIdCounter.incrementAndGet() % (MAX_REQUEST_ID + 1);
+            // Skip ids below MIN_REQUEST_ID (0 and 1) to avoid possible authentication issues
+            if (requestId < MIN_REQUEST_ID) {
+                requestId = MIN_REQUEST_ID;
+            }
             sIdCounter.set(requestId);
         } while (isSecondaryProvider(requestId) != isSecondary);
+        if (sDebug) {
+            Slog.d(TAG, "nextId(): requestId = " + requestId);
+        }
         return requestId;
-  }
+    }
 
-  public static boolean isSecondaryProvider(int requestId) {
-      return requestId % 2 == 1;
-  }
+    public static boolean isSecondaryProvider(int requestId) {
+        return requestId % 2 != 0; // != 0 (not == 1) so that negative odd ids count as odd
+    }
 }
diff --git a/services/autofill/java/com/android/server/autofill/Session.java b/services/autofill/java/com/android/server/autofill/Session.java
index 494e956..c6ddc16 100644
--- a/services/autofill/java/com/android/server/autofill/Session.java
+++ b/services/autofill/java/com/android/server/autofill/Session.java
@@ -6902,17 +6902,18 @@
         return mPendingSaveUi != null && mPendingSaveUi.getState() == PendingUi.STATE_PENDING;
     }
 
+    // Returns the index in mResponses of the most recent response, or -1 if there is none.
     @GuardedBy("mLock")
     private int getLastResponseIndexLocked() {
-        if (mResponses != null) {
-            List<Integer> requestIdList = new ArrayList<>();
-            final int responseCount = mResponses.size();
-            for (int i = 0; i < responseCount; i++) {
-                requestIdList.add(mResponses.keyAt(i));
-            }
-            return mRequestId.getLastRequestIdIndex(requestIdList);
+        if (mResponses == null || mResponses.size() == 0) {
+            return -1;
         }
-        return -1;
+        final int responseCount = mResponses.size();
+        List<Integer> requestIdList = new ArrayList<>(responseCount);
+        for (int i = 0; i < responseCount; i++) {
+            requestIdList.add(mResponses.keyAt(i));
+        }
+        return RequestId.getLastRequestIdIndex(requestIdList);
     }
 
     private LogMaker newLogMaker(int category) {
diff --git a/services/tests/servicestests/src/com/android/server/autofill/RequestIdTest.java b/services/tests/servicestests/src/com/android/server/autofill/RequestIdTest.java
index 6d56c41..60c3659 100644
--- a/services/tests/servicestests/src/com/android/server/autofill/RequestIdTest.java
+++ b/services/tests/servicestests/src/com/android/server/autofill/RequestIdTest.java
@@ -17,17 +17,25 @@
 
 import static com.google.common.truth.Truth.assertThat;
 
+import android.util.Slog;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+
 import org.junit.Before;
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.junit.runners.JUnit4;
 
-import java.util.ArrayList;
-import java.util.List;
-
 @RunWith(JUnit4.class)
 public class RequestIdTest {
 
+    private static final int TEST_DATASET_SIZE = 300;
+    private static final int TEST_WRAP_SIZE = 50; // Number of request ids before wrap happens
+    private static final String TAG = "RequestIdTest";
+
     List<Integer> datasetPrimaryNoWrap = new ArrayList<>();
     List<Integer> datasetPrimaryWrap = new ArrayList<>();
     List<Integer> datasetSecondaryNoWrap = new ArrayList<>();
@@ -35,151 +43,201 @@
     List<Integer> datasetMixedNoWrap = new ArrayList<>();
     List<Integer> datasetMixedWrap = new ArrayList<>();
 
-    @Before
-    public void setup() throws Exception {
-      int datasetSize = 300;
+    List<Integer> manualWrapRequestIdList = Arrays.asList(3, 9, 15,
+                                                            RequestId.MAX_SECONDARY_REQUEST_ID - 5,
+                                                            RequestId.MAX_SECONDARY_REQUEST_ID - 3);
+    List<Integer> manualNoWrapRequestIdList = Arrays.asList(2, 6, 10, 14, 18, 22, 26, 30);
 
+    List<Integer> manualOneElementRequestIdList = Arrays.asList(1);
+
+    @Before
+    public void setup() {
+        Slog.d(TAG, "setup()");
         { // Generate primary only ids that do not wrap
-            RequestId requestId = new RequestId(0);
-            for (int i = 0; i < datasetSize; i++) {
+            RequestId requestId = new RequestId(RequestId.MIN_PRIMARY_REQUEST_ID);
+            for (int i = 0; i < TEST_DATASET_SIZE; i++) {
                 datasetPrimaryNoWrap.add(requestId.nextId(false));
             }
+            Collections.sort(datasetPrimaryNoWrap);
         }
 
         { // Generate primary only ids that wrap
-            RequestId requestId = new RequestId(0xff00);
-            for (int i = 0; i < datasetSize; i++) {
+            RequestId requestId = new RequestId(RequestId.MAX_PRIMARY_REQUEST_ID -
+                                                    TEST_WRAP_SIZE * 2);
+            for (int i = 0; i < TEST_DATASET_SIZE; i++) {
                 datasetPrimaryWrap.add(requestId.nextId(false));
             }
+            Collections.sort(datasetPrimaryWrap);
         }
 
         { // Generate SECONDARY only ids that do not wrap
-            RequestId requestId = new RequestId(0);
-            for (int i = 0; i < datasetSize; i++) {
+            RequestId requestId = new RequestId(RequestId.MIN_SECONDARY_REQUEST_ID);
+            for (int i = 0; i < TEST_DATASET_SIZE; i++) {
                 datasetSecondaryNoWrap.add(requestId.nextId(true));
             }
+            Collections.sort(datasetSecondaryNoWrap);
         }
 
         { // Generate SECONDARY only ids that wrap
-            RequestId requestId = new RequestId(0xff00);
-            for (int i = 0; i < datasetSize; i++) {
+            RequestId requestId = new RequestId(RequestId.MAX_SECONDARY_REQUEST_ID -
+                                                    TEST_WRAP_SIZE * 2);
+            for (int i = 0; i < TEST_DATASET_SIZE; i++) {
                 datasetSecondaryWrap.add(requestId.nextId(true));
             }
+            Collections.sort(datasetSecondaryWrap);
         }
 
         { // Generate MIXED only ids that do not wrap
-            RequestId requestId = new RequestId(0);
-            for (int i = 0; i < datasetSize; i++) {
+            RequestId requestId = new RequestId(RequestId.MIN_REQUEST_ID);
+            for (int i = 0; i < TEST_DATASET_SIZE; i++) {
                 datasetMixedNoWrap.add(requestId.nextId(i % 2 != 0));
             }
+            Collections.sort(datasetMixedNoWrap);
         }
 
         { // Generate MIXED only ids that wrap
-            RequestId requestId = new RequestId(0xff00);
-            for (int i = 0; i < datasetSize; i++) {
+            RequestId requestId = new RequestId(RequestId.MAX_REQUEST_ID -
+                                                    TEST_WRAP_SIZE);
+            for (int i = 0; i < TEST_DATASET_SIZE; i++) {
                 datasetMixedWrap.add(requestId.nextId(i % 2 != 0));
             }
+            Collections.sort(datasetMixedWrap);
         }
+        Slog.d(TAG, "finishing setup()");
     }
 
     @Test
     public void testRequestIdLists() {
+        Slog.d(TAG, "testRequestIdLists()");
         for (int id : datasetPrimaryNoWrap) {
             assertThat(RequestId.isSecondaryProvider(id)).isFalse();
-            assertThat(id >= 0).isTrue();
-            assertThat(id < 0xffff).isTrue();
+            assertThat(id).isAtLeast(RequestId.MIN_PRIMARY_REQUEST_ID);
+            assertThat(id).isAtMost(RequestId.MAX_PRIMARY_REQUEST_ID);
         }
 
         for (int id : datasetPrimaryWrap) {
             assertThat(RequestId.isSecondaryProvider(id)).isFalse();
-            assertThat(id >= 0).isTrue();
-            assertThat(id < 0xffff).isTrue();
+            assertThat(id).isAtLeast(RequestId.MIN_PRIMARY_REQUEST_ID);
+            assertThat(id).isAtMost(RequestId.MAX_PRIMARY_REQUEST_ID);
         }
 
         for (int id : datasetSecondaryNoWrap) {
             assertThat(RequestId.isSecondaryProvider(id)).isTrue();
-            assertThat(id >= 0).isTrue();
-            assertThat(id < 0xffff).isTrue();
+            assertThat(id).isAtLeast(RequestId.MIN_SECONDARY_REQUEST_ID);
+            assertThat(id).isAtMost(RequestId.MAX_SECONDARY_REQUEST_ID);
         }
 
         for (int id : datasetSecondaryWrap) {
             assertThat(RequestId.isSecondaryProvider(id)).isTrue();
-            assertThat(id >= 0).isTrue();
-            assertThat(id < 0xffff).isTrue();
+            assertThat(id).isAtLeast(RequestId.MIN_SECONDARY_REQUEST_ID);
+            assertThat(id).isAtMost(RequestId.MAX_SECONDARY_REQUEST_ID);
         }
     }
 
     @Test
-    public void testRequestIdGeneration() {
-        RequestId requestId = new RequestId(0);
+    public void testCreateNewRequestId() {
+        Slog.d(TAG, "testCreateNewRequestId()");
+        for (int i = 0; i < 100000; i++) {
+            RequestId requestId = new RequestId();
+            assertThat(requestId.getRequestId()).isAtLeast(RequestId.MIN_REQUEST_ID);
+            assertThat(requestId.getRequestId()).isAtMost(RequestId.MAX_START_ID);
+        }
+    }
 
+    @Test
+    public void testGetNextRequestId() {
+        Slog.d(TAG, "testGetNextRequestId()");
+        RequestId requestId = new RequestId();
         // Large Primary
         for (int i = 0; i < 100000; i++) {
             int y = requestId.nextId(false);
             assertThat(RequestId.isSecondaryProvider(y)).isFalse();
-            assertThat(y >= 0).isTrue();
-            assertThat(y < 0xffff).isTrue();
+            assertThat(y).isAtLeast(RequestId.MIN_PRIMARY_REQUEST_ID);
+            assertThat(y).isAtMost(RequestId.MAX_PRIMARY_REQUEST_ID);
         }
 
         // Large Secondary
-        requestId = new RequestId(0);
+        requestId = new RequestId();
         for (int i = 0; i < 100000; i++) {
             int y = requestId.nextId(true);
             assertThat(RequestId.isSecondaryProvider(y)).isTrue();
-            assertThat(y >= 0).isTrue();
-            assertThat(y < 0xffff).isTrue();
+            assertThat(y).isAtLeast(RequestId.MIN_SECONDARY_REQUEST_ID);
+            assertThat(y).isAtMost(RequestId.MAX_SECONDARY_REQUEST_ID);
         }
 
         // Large Mixed
-        requestId = new RequestId(0);
+        requestId = new RequestId();
         for (int i = 0; i < 50000; i++) {
             int y = requestId.nextId(i % 2 != 0);
-            assertThat(RequestId.isSecondaryProvider(y)).isEqualTo(i % 2 == 0);
-            assertThat(y >= 0).isTrue();
-            assertThat(y < 0xffff).isTrue();
+            assertThat(RequestId.isSecondaryProvider(y)).isEqualTo(i % 2 != 0);
+            assertThat(y).isAtLeast(RequestId.MIN_REQUEST_ID);
+            assertThat(y).isAtMost(RequestId.MAX_REQUEST_ID);
         }
     }
 
     @Test
     public void testGetLastRequestId() {
-        // In this test, request ids are generated FIFO, so the last entry is also the last
-        // request
+        Slog.d(TAG, "testGetLastRequestId()");
 
-        { // Primary no wrap
-          int lastIdIndex = datasetPrimaryNoWrap.size() - 1;
-          int lastComputedIdIndex = RequestId.getLastRequestIdIndex(datasetPrimaryNoWrap);
-          assertThat(lastIdIndex).isEqualTo(lastComputedIdIndex);
-        }
-
-        { // Primary wrap
-            int lastIdIndex = datasetPrimaryWrap.size() - 1;
-            int lastComputedIdIndex = RequestId.getLastRequestIdIndex(datasetPrimaryWrap);
+        {   // Primary no wrap
+            int lastIdIndex = datasetPrimaryNoWrap.size() - 1;
+            int lastComputedIdIndex = RequestId.getLastRequestIdIndex(datasetPrimaryNoWrap);
             assertThat(lastIdIndex).isEqualTo(lastComputedIdIndex);
         }
 
-        { // Secondary no wrap
+        {   // Primary wrap
+            // The last index would be the # of request ids left after wrap
+            // minus 1 (index starts at 0)
+            int lastIdIndex = TEST_DATASET_SIZE - TEST_WRAP_SIZE - 1;
+            int lastComputedIdIndex = RequestId.getLastRequestIdIndex(datasetPrimaryWrap);
+            assertThat(lastComputedIdIndex).isEqualTo(lastIdIndex);
+        }
+
+        {   // Secondary no wrap
             int lastIdIndex = datasetSecondaryNoWrap.size() - 1;
             int lastComputedIdIndex = RequestId.getLastRequestIdIndex(datasetSecondaryNoWrap);
             assertThat(lastIdIndex).isEqualTo(lastComputedIdIndex);
         }
 
-        { // Secondary wrap
-            int lastIdIndex = datasetSecondaryWrap.size() - 1;
+        {   // Secondary wrap
+            int lastIdIndex = TEST_DATASET_SIZE - TEST_WRAP_SIZE - 1;
             int lastComputedIdIndex = RequestId.getLastRequestIdIndex(datasetSecondaryWrap);
             assertThat(lastIdIndex).isEqualTo(lastComputedIdIndex);
         }
 
-        { // Mixed no wrap
+        {   // Mixed no wrap
             int lastIdIndex = datasetMixedNoWrap.size() - 1;
             int lastComputedIdIndex = RequestId.getLastRequestIdIndex(datasetMixedNoWrap);
             assertThat(lastIdIndex).isEqualTo(lastComputedIdIndex);
         }
 
-        { // Mixed wrap
-            int lastIdIndex = datasetMixedWrap.size() - 1;
+        {   // Mixed wrap
+            int lastIdIndex = TEST_DATASET_SIZE - TEST_WRAP_SIZE - 1;
             int lastComputedIdIndex = RequestId.getLastRequestIdIndex(datasetMixedWrap);
             assertThat(lastIdIndex).isEqualTo(lastComputedIdIndex);
         }
 
+        {   // Manual wrap
+            int lastIdIndex = 2; // [3, 9, 15,
+                                 // MAX_SECONDARY_REQUEST_ID - 5, MAX_SECONDARY_REQUEST_ID - 3]
+            int lastComputedIdIndex = RequestId.getLastRequestIdIndex(manualWrapRequestIdList);
+            assertThat(lastComputedIdIndex).isEqualTo(lastIdIndex);
+        }
+
+        {   // Manual no wrap
+            int lastIdIndex = manualNoWrapRequestIdList.size() - 1; // [2, 6, 10, 14,
+                                                                    // 18, 22, 26, 30]
+            int lastComputedIdIndex = RequestId.getLastRequestIdIndex(manualNoWrapRequestIdList);
+            assertThat(lastComputedIdIndex).isEqualTo(lastIdIndex);
+
+        }
+
+        {   // Manual one element
+            int lastIdIndex = 0; // [1]
+            int lastComputedIdIndex = RequestId.getLastRequestIdIndex(
+                manualOneElementRequestIdList);
+            assertThat(lastComputedIdIndex).isEqualTo(lastIdIndex);
+
+        }
     }
 }