Support keys that have uncommon width.

Bug: 8591918

Change-Id: I1e01e1560200333f9e35993af0aa7e5a17e6944f
diff --git a/native/jni/src/suggest/core/layout/proximity_info.cpp b/native/jni/src/suggest/core/layout/proximity_info.cpp
index 80355c1..05826a5 100644
--- a/native/jni/src/suggest/core/layout/proximity_info.cpp
+++ b/native/jni/src/suggest/core/layout/proximity_info.cpp
@@ -134,24 +134,13 @@
 }
 
 float ProximityInfo::getNormalizedSquaredDistanceFromCenterFloatG(
-        const int keyId, const int x, const int y, const float verticalScale) const {
-    const bool correctTouchPosition = hasTouchPositionCorrectionData();
-    const float centerX = static_cast<float>(correctTouchPosition ? getSweetSpotCenterXAt(keyId)
-            : getKeyCenterXOfKeyIdG(keyId));
-    const float visualKeyCenterY = static_cast<float>(getKeyCenterYOfKeyIdG(keyId));
-    float centerY;
-    if (correctTouchPosition) {
-        const float sweetSpotCenterY = static_cast<float>(getSweetSpotCenterYAt(keyId));
-        const float gapY = sweetSpotCenterY - visualKeyCenterY;
-        centerY = visualKeyCenterY + gapY * verticalScale;
-    } else {
-        centerY = visualKeyCenterY;
-    }
+        const int keyId, const int x, const int y, const bool isGeometric) const {
+    const float centerX = static_cast<float>(getKeyCenterXOfKeyIdG(keyId, x, isGeometric));
+    const float centerY = static_cast<float>(getKeyCenterYOfKeyIdG(keyId, y, isGeometric));
     const float touchX = static_cast<float>(x);
     const float touchY = static_cast<float>(y);
-    const float keyWidth = static_cast<float>(getMostCommonKeyWidth());
     return ProximityInfoUtils::getSquaredDistanceFloat(centerX, centerY, touchX, touchY)
-            / GeometryUtils::SQUARE_FLOAT(keyWidth);
+            / GeometryUtils::SQUARE_FLOAT(static_cast<float>(getMostCommonKeyWidth()));
 }
 
 int ProximityInfo::getCodePointOf(const int keyIndex) const {
@@ -168,41 +157,80 @@
         const int lowerCode = CharUtils::toLowerCase(code);
         mCenterXsG[i] = mKeyXCoordinates[i] + mKeyWidths[i] / 2;
         mCenterYsG[i] = mKeyYCoordinates[i] + mKeyHeights[i] / 2;
+        if (hasTouchPositionCorrectionData()) {
+            // Computes sweet spot center points for geometric input.
+            const float verticalScale = ProximityInfoParams::VERTICAL_SWEET_SPOT_SCALE_G;
+            const float sweetSpotCenterY = static_cast<float>(mSweetSpotCenterYs[i]);
+            const float gapY = sweetSpotCenterY - mCenterYsG[i];
+            mSweetSpotCenterYsG[i] = static_cast<int>(mCenterYsG[i] + gapY * verticalScale);
+        }
         mCodeToKeyMap[lowerCode] = i;
         mKeyIndexToCodePointG[i] = lowerCode;
     }
     for (int i = 0; i < KEY_COUNT; i++) {
         mKeyKeyDistancesG[i][i] = 0;
         for (int j = i + 1; j < KEY_COUNT; j++) {
-            mKeyKeyDistancesG[i][j] = GeometryUtils::getDistanceInt(
-                    mCenterXsG[i], mCenterYsG[i], mCenterXsG[j], mCenterYsG[j]);
+            if (hasTouchPositionCorrectionData()) {
+                // Computes distances using sweet spots if they exist.
+                // We have two types of Y coordinate sweet spots, for geometric and for the others.
+                // The sweet spots for geometric input are used for calculating key-key distances
+                // here.
+                mKeyKeyDistancesG[i][j] = GeometryUtils::getDistanceInt(
+                        mSweetSpotCenterXs[i], mSweetSpotCenterYsG[i],
+                        mSweetSpotCenterXs[j], mSweetSpotCenterYsG[j]);
+            } else {
+                mKeyKeyDistancesG[i][j] = GeometryUtils::getDistanceInt(
+                        mCenterXsG[i], mCenterYsG[i], mCenterXsG[j], mCenterYsG[j]);
+            }
             mKeyKeyDistancesG[j][i] = mKeyKeyDistancesG[i][j];
         }
     }
 }
 
-int ProximityInfo::getKeyCenterXOfCodePointG(int charCode) const {
-    return getKeyCenterXOfKeyIdG(
-            ProximityInfoUtils::getKeyIndexOf(KEY_COUNT, charCode, &mCodeToKeyMap));
-}
-
-int ProximityInfo::getKeyCenterYOfCodePointG(int charCode) const {
-    return getKeyCenterYOfKeyIdG(
-            ProximityInfoUtils::getKeyIndexOf(KEY_COUNT, charCode, &mCodeToKeyMap));
-}
-
-int ProximityInfo::getKeyCenterXOfKeyIdG(int keyId) const {
-    if (keyId >= 0) {
-        return mCenterXsG[keyId];
+// referencePointX is used only for keys wider than most common key width. When the referencePointX
+// is NOT_A_COORDINATE, this method calculates the return value without using the line segment.
+// isGeometric is currently not used because we don't have extra X coordinates sweet spots for
+// geometric input.
+int ProximityInfo::getKeyCenterXOfKeyIdG(
+        const int keyId, const int referencePointX, const bool isGeometric) const {
+    if (keyId < 0) {
+        return 0;
     }
-    return 0;
+    int centerX = (hasTouchPositionCorrectionData()) ? static_cast<int>(mSweetSpotCenterXs[keyId])
+            : mCenterXsG[keyId];
+    const int keyWidth = mKeyWidths[keyId];
+    if (referencePointX != NOT_A_COORDINATE
+            && keyWidth > getMostCommonKeyWidth()) {
+        // For keys wider than most common keys, we use a line segment instead of the center point;
+        // thus, centerX is adjusted depending on referencePointX.
+        const int keyWidthHalfDiff = (keyWidth - getMostCommonKeyWidth()) / 2;
+        if (referencePointX < centerX - keyWidthHalfDiff) {
+            centerX -= keyWidthHalfDiff;
+        } else if (referencePointX > centerX + keyWidthHalfDiff) {
+            centerX += keyWidthHalfDiff;
+        } else {
+            centerX = referencePointX;
+        }
+    }
+    return centerX;
 }
 
-int ProximityInfo::getKeyCenterYOfKeyIdG(int keyId) const {
-    if (keyId >= 0) {
+// referencePointY is currently not used because we don't specially handle keys higher than the
+// most common key height. When the referencePointY is NOT_A_COORDINATE, this method should
+// calculate the return value without using the line segment.
+int ProximityInfo::getKeyCenterYOfKeyIdG(
+        const int keyId,  const int referencePointY, const bool isGeometric) const {
+    // TODO: Remove "isGeometric" and have separate "proximity_info"s for gesture and typing.
+    if (keyId < 0) {
+        return 0;
+    }
+    if (!hasTouchPositionCorrectionData()) {
         return mCenterYsG[keyId];
+    } else if (isGeometric) {
+        return static_cast<int>(mSweetSpotCenterYsG[keyId]);
+    } else {
+        return static_cast<int>(mSweetSpotCenterYs[keyId]);
     }
-    return 0;
 }
 
 int ProximityInfo::getKeyKeyDistanceG(const int keyId0, const int keyId1) const {
diff --git a/native/jni/src/suggest/core/layout/proximity_info.h b/native/jni/src/suggest/core/layout/proximity_info.h
index 534c2c2..f259490 100644
--- a/native/jni/src/suggest/core/layout/proximity_info.h
+++ b/native/jni/src/suggest/core/layout/proximity_info.h
@@ -37,8 +37,7 @@
     bool hasSpaceProximity(const int x, const int y) const;
     int getNormalizedSquaredDistance(const int inputIndex, const int proximityIndex) const;
     float getNormalizedSquaredDistanceFromCenterFloatG(
-            const int keyId, const int x, const int y,
-            const float verticalScale) const;
+            const int keyId, const int x, const int y, const bool isGeometric) const;
     int getCodePointOf(const int keyIndex) const;
     bool hasSweetSpotData(const int keyIndex) const {
         // When there are no calibration data for a key,
@@ -65,10 +64,10 @@
     int getKeyboardHeight() const { return KEYBOARD_HEIGHT; }
     float getKeyboardHypotenuse() const { return KEYBOARD_HYPOTENUSE; }
 
-    int getKeyCenterXOfCodePointG(int charCode) const;
-    int getKeyCenterYOfCodePointG(int charCode) const;
-    int getKeyCenterXOfKeyIdG(int keyId) const;
-    int getKeyCenterYOfKeyIdG(int keyId) const;
+    int getKeyCenterXOfKeyIdG(
+            const int keyId, const int referencePointX, const bool isGeometric) const;
+    int getKeyCenterYOfKeyIdG(
+            const int keyId, const int referencePointY, const bool isGeometric) const;
     int getKeyKeyDistanceG(int keyId0, int keyId1) const;
 
     AK_FORCE_INLINE void initializeProximities(const int *const inputCodes,
@@ -115,6 +114,8 @@
     int mKeyCodePoints[MAX_KEY_COUNT_IN_A_KEYBOARD];
     float mSweetSpotCenterXs[MAX_KEY_COUNT_IN_A_KEYBOARD];
     float mSweetSpotCenterYs[MAX_KEY_COUNT_IN_A_KEYBOARD];
+    // Sweet spots for geometric input. Note that we have extra sweet spots only for Y coordinates.
+    float mSweetSpotCenterYsG[MAX_KEY_COUNT_IN_A_KEYBOARD];
     float mSweetSpotRadii[MAX_KEY_COUNT_IN_A_KEYBOARD];
     hash_map_compat<int, int> mCodeToKeyMap;
 
diff --git a/native/jni/src/suggest/core/layout/proximity_info_state.cpp b/native/jni/src/suggest/core/layout/proximity_info_state.cpp
index e8d9500..7780efd 100644
--- a/native/jni/src/suggest/core/layout/proximity_info_state.cpp
+++ b/native/jni/src/suggest/core/layout/proximity_info_state.cpp
@@ -97,15 +97,10 @@
                 pushTouchPointStartIndex, lastSavedInputSize);
     }
 
-    // TODO: Remove the dependency of "isGeometric"
-    const float verticalSweetSpotScale = isGeometric
-            ? ProximityInfoParams::VERTICAL_SWEET_SPOT_SCALE_G
-            : ProximityInfoParams::VERTICAL_SWEET_SPOT_SCALE;
-
     if (xCoordinates && yCoordinates) {
         mSampledInputSize = ProximityInfoStateUtils::updateTouchPoints(mProximityInfo,
                 mMaxPointToKeyLength, mInputProximities, xCoordinates, yCoordinates, times,
-                pointerIds, verticalSweetSpotScale, inputSize, isGeometric, pointerId,
+                pointerIds, inputSize, isGeometric, pointerId,
                 pushTouchPointStartIndex, &mSampledInputXs, &mSampledInputYs, &mSampledTimes,
                 &mSampledLengthCache, &mSampledInputIndice);
     }
@@ -123,7 +118,7 @@
 
     if (mSampledInputSize > 0) {
         ProximityInfoStateUtils::initGeometricDistanceInfos(mProximityInfo, mSampledInputSize,
-                lastSavedInputSize, verticalSweetSpotScale, &mSampledInputXs, &mSampledInputYs,
+                lastSavedInputSize, isGeometric, &mSampledInputXs, &mSampledInputYs,
                 &mSampledNearKeySets, &mSampledNormalizedSquaredLengthCache);
         if (isGeometric) {
             // updates probabilities of skipping or mapping each key for all points.
diff --git a/native/jni/src/suggest/core/layout/proximity_info_state_utils.cpp b/native/jni/src/suggest/core/layout/proximity_info_state_utils.cpp
index 1bbae65..904671f 100644
--- a/native/jni/src/suggest/core/layout/proximity_info_state_utils.cpp
+++ b/native/jni/src/suggest/core/layout/proximity_info_state_utils.cpp
@@ -43,8 +43,8 @@
         const ProximityInfo *const proximityInfo, const int maxPointToKeyLength,
         const int *const inputProximities, const int *const inputXCoordinates,
         const int *const inputYCoordinates, const int *const times, const int *const pointerIds,
-        const float verticalSweetSpotScale, const int inputSize, const bool isGeometric,
-        const int pointerId, const int pushTouchPointStartIndex, std::vector<int> *sampledInputXs,
+        const int inputSize, const bool isGeometric, const int pointerId,
+        const int pushTouchPointStartIndex, std::vector<int> *sampledInputXs,
         std::vector<int> *sampledInputYs, std::vector<int> *sampledInputTimes,
         std::vector<int> *sampledLengthCache, std::vector<int> *sampledInputIndice) {
     if (DEBUG_SAMPLING_POINTS) {
@@ -113,7 +113,7 @@
             }
 
             if (pushTouchPoint(proximityInfo, maxPointToKeyLength, i, c, x, y, time,
-                    verticalSweetSpotScale, isGeometric /* doSampling */, i == lastInputIndex,
+                    isGeometric, isGeometric /* doSampling */, i == lastInputIndex,
                     sumAngle, currentNearKeysDistances, prevNearKeysDistances,
                     prevPrevNearKeysDistances, sampledInputXs, sampledInputYs, sampledInputTimes,
                     sampledLengthCache, sampledInputIndice)) {
@@ -183,7 +183,7 @@
 
 /* static */ void ProximityInfoStateUtils::initGeometricDistanceInfos(
         const ProximityInfo *const proximityInfo, const int sampledInputSize,
-        const int lastSavedInputSize, const float verticalSweetSpotScale,
+        const int lastSavedInputSize, const bool isGeometric,
         const std::vector<int> *const sampledInputXs,
         const std::vector<int> *const sampledInputYs,
         std::vector<NearKeycodesSet> *sampledNearKeySets,
@@ -199,7 +199,7 @@
             const int y = (*sampledInputYs)[i];
             const float normalizedSquaredDistance =
                     proximityInfo->getNormalizedSquaredDistanceFromCenterFloatG(
-                            k, x, y, verticalSweetSpotScale);
+                            k, x, y, isGeometric);
             (*sampledNormalizedSquaredLengthCache)[index] = normalizedSquaredDistance;
             if (normalizedSquaredDistance
                     < ProximityInfoParams::NEAR_KEY_NORMALIZED_SQUARED_THRESHOLD) {
@@ -317,14 +317,13 @@
 // the given point and the nearest key position.
 /* static */ float ProximityInfoStateUtils::updateNearKeysDistances(
         const ProximityInfo *const proximityInfo, const float maxPointToKeyLength, const int x,
-        const int y, const float verticalSweetspotScale,
-        NearKeysDistanceMap *const currentNearKeysDistances) {
+        const int y, const bool isGeometric, NearKeysDistanceMap *const currentNearKeysDistances) {
     currentNearKeysDistances->clear();
     const int keyCount = proximityInfo->getKeyCount();
     float nearestKeyDistance = maxPointToKeyLength;
     for (int k = 0; k < keyCount; ++k) {
         const float dist = proximityInfo->getNormalizedSquaredDistanceFromCenterFloatG(k, x, y,
-                verticalSweetspotScale);
+                isGeometric);
         if (dist < ProximityInfoParams::NEAR_KEY_THRESHOLD_FOR_DISTANCE) {
             currentNearKeysDistances->insert(std::pair<int, float>(k, dist));
         }
@@ -405,7 +404,7 @@
 // Returning if previous point is popped or not.
 /* static */ bool ProximityInfoStateUtils::pushTouchPoint(const ProximityInfo *const proximityInfo,
         const int maxPointToKeyLength, const int inputIndex, const int nodeCodePoint, int x, int y,
-        const int time, const float verticalSweetSpotScale, const bool doSampling,
+        const int time, const bool isGeometric, const bool doSampling,
         const bool isLastPoint, const float sumAngle,
         NearKeysDistanceMap *const currentNearKeysDistances,
         const NearKeysDistanceMap *const prevNearKeysDistances,
@@ -419,7 +418,7 @@
     bool popped = false;
     if (nodeCodePoint < 0 && doSampling) {
         const float nearest = updateNearKeysDistances(proximityInfo, maxPointToKeyLength, x, y,
-                verticalSweetSpotScale, currentNearKeysDistances);
+                isGeometric, currentNearKeysDistances);
         const float score = getPointScore(mostCommonKeyWidth, x, y, time, isLastPoint, nearest,
                 sumAngle, currentNearKeysDistances, prevNearKeysDistances,
                 prevPrevNearKeysDistances, sampledInputXs, sampledInputYs);
@@ -453,8 +452,8 @@
     if (nodeCodePoint >= 0 && (x < 0 || y < 0)) {
         const int keyId = proximityInfo->getKeyIndexOf(nodeCodePoint);
         if (keyId >= 0) {
-            x = proximityInfo->getKeyCenterXOfKeyIdG(keyId);
-            y = proximityInfo->getKeyCenterYOfKeyIdG(keyId);
+            x = proximityInfo->getKeyCenterXOfKeyIdG(keyId, NOT_AN_INDEX, isGeometric);
+            y = proximityInfo->getKeyCenterYOfKeyIdG(keyId, NOT_AN_INDEX, isGeometric);
         }
     }
 
diff --git a/native/jni/src/suggest/core/layout/proximity_info_state_utils.h b/native/jni/src/suggest/core/layout/proximity_info_state_utils.h
index 66fe079..6de9700 100644
--- a/native/jni/src/suggest/core/layout/proximity_info_state_utils.h
+++ b/native/jni/src/suggest/core/layout/proximity_info_state_utils.h
@@ -38,8 +38,7 @@
     static int updateTouchPoints(const ProximityInfo *const proximityInfo,
             const int maxPointToKeyLength, const int *const inputProximities,
             const int *const inputXCoordinates, const int *const inputYCoordinates,
-            const int *const times, const int *const pointerIds,
-            const float verticalSweetSpotScale, const int inputSize,
+            const int *const times, const int *const pointerIds, const int inputSize,
             const bool isGeometric, const int pointerId, const int pushTouchPointStartIndex,
             std::vector<int> *sampledInputXs, std::vector<int> *sampledInputYs,
             std::vector<int> *sampledInputTimes, std::vector<int> *sampledLengthCache,
@@ -84,8 +83,7 @@
             const std::vector<float> *const sampledNormalizedSquaredLengthCache, const int keyCount,
             const int inputIndex, const int keyId);
     static void initGeometricDistanceInfos(const ProximityInfo *const proximityInfo,
-            const int sampledInputSize, const int lastSavedInputSize,
-            const float verticalSweetSpotScale,
+            const int sampledInputSize, const int lastSavedInputSize, const bool isGeometric,
             const std::vector<int> *const sampledInputXs,
             const std::vector<int> *const sampledInputYs,
             std::vector<NearKeycodesSet> *sampledNearKeySets,
@@ -120,7 +118,7 @@
 
     static float updateNearKeysDistances(const ProximityInfo *const proximityInfo,
             const float maxPointToKeyLength, const int x, const int y,
-            const float verticalSweetSpotScale,
+            const bool isGeometric,
             NearKeysDistanceMap *const currentNearKeysDistances);
     static bool isPrevLocalMin(const NearKeysDistanceMap *const currentNearKeysDistances,
             const NearKeysDistanceMap *const prevNearKeysDistances,
@@ -133,7 +131,7 @@
             std::vector<int> *sampledInputXs, std::vector<int> *sampledInputYs);
     static bool pushTouchPoint(const ProximityInfo *const proximityInfo,
             const int maxPointToKeyLength, const int inputIndex, const int nodeCodePoint, int x,
-            int y, const int time, const float verticalSweetSpotScale,
+            int y, const int time, const bool isGeometric,
             const bool doSampling, const bool isLastPoint,
             const float sumAngle, NearKeysDistanceMap *const currentNearKeysDistances,
             const NearKeysDistanceMap *const prevNearKeysDistances,