Support terminal insertion error correction

Bug: 9421356

Change-Id: I19685763ca487b5636019d62e150708c63ce6fc2
diff --git a/native/jni/src/defines.h b/native/jni/src/defines.h
index 974bb48..34a646f 100644
--- a/native/jni/src/defines.h
+++ b/native/jni/src/defines.h
@@ -381,6 +381,7 @@
     CT_TRANSPOSITION,
     CT_COMPLETION,
     CT_TERMINAL,
+    CT_TERMINAL_INSERTION,
     // Create new word with space omission
     CT_NEW_WORD_SPACE_OMITTION,
     // Create new word with space substitution
diff --git a/native/jni/src/suggest/core/dicnode/dic_node_profiler.h b/native/jni/src/suggest/core/dicnode/dic_node_profiler.h
index 90f75d0..1f4d257 100644
--- a/native/jni/src/suggest/core/dicnode/dic_node_profiler.h
+++ b/native/jni/src/suggest/core/dicnode/dic_node_profiler.h
@@ -31,6 +31,7 @@
 #define PROF_TRANSPOSITION(profiler) profiler.profTransposition()
 #define PROF_NEARESTKEY(profiler) profiler.profNearestKey()
 #define PROF_TERMINAL(profiler) profiler.profTerminal()
+#define PROF_TERMINAL_INSERTION(profiler) profiler.profTerminalInsertion()
 #define PROF_NEW_WORD(profiler) profiler.profNewWord()
 #define PROF_NEW_WORD_BIGRAM(profiler) profiler.profNewWordBigram()
 #define PROF_NODE_RESET(profiler) profiler.reset()
@@ -47,6 +48,7 @@
 #define PROF_TRANSPOSITION(profiler)
 #define PROF_NEARESTKEY(profiler)
 #define PROF_TERMINAL(profiler)
+#define PROF_TERMINAL_INSERTION(profiler)
 #define PROF_NEW_WORD(profiler)
 #define PROF_NEW_WORD_BIGRAM(profiler)
 #define PROF_NODE_RESET(profiler)
@@ -62,7 +64,7 @@
             : mProfOmission(0), mProfInsertion(0), mProfTransposition(0),
               mProfAdditionalProximity(0), mProfSubstitution(0),
               mProfSpaceSubstitution(0), mProfSpaceOmission(0),
-              mProfMatch(0), mProfCompletion(0), mProfTerminal(0),
+              mProfMatch(0), mProfCompletion(0), mProfTerminal(0), mProfTerminalInsertion(0),
               mProfNearestKey(0), mProfNewWord(0), mProfNewWordBigram(0) {}
 
     int mProfOmission;
@@ -75,6 +77,7 @@
     int mProfMatch;
     int mProfCompletion;
     int mProfTerminal;
+    int mProfTerminalInsertion;
     int mProfNearestKey;
     int mProfNewWord;
     int mProfNewWordBigram;
@@ -123,6 +126,10 @@
         ++mProfTerminal;
     }
 
+    void profTerminalInsertion() {
+        ++mProfTerminalInsertion; // Bumped via PROF_TERMINAL_INSERTION for CT_TERMINAL_INSERTION.
+    }
+
     void profNewWord() {
         ++mProfNewWord;
     }
diff --git a/native/jni/src/suggest/core/policy/weighting.cpp b/native/jni/src/suggest/core/policy/weighting.cpp
index 117f48f..5872922 100644
--- a/native/jni/src/suggest/core/policy/weighting.cpp
+++ b/native/jni/src/suggest/core/policy/weighting.cpp
@@ -50,6 +50,9 @@
     case CT_TERMINAL:
         PROF_TERMINAL(node->mProfiler);
         return;
+    case CT_TERMINAL_INSERTION:
+        PROF_TERMINAL_INSERTION(node->mProfiler);
+        return;
     case CT_NEW_WORD_SPACE_SUBSTITUTION:
         PROF_SPACE_SUBSTITUTION(node->mProfiler);
         return;
@@ -113,6 +116,8 @@
         return weighting->getCompletionCost(traverseSession, dicNode);
     case CT_TERMINAL:
         return weighting->getTerminalSpatialCost(traverseSession, dicNode);
+    case CT_TERMINAL_INSERTION:
+        return weighting->getTerminalInsertionCost(traverseSession, dicNode);
     case CT_NEW_WORD_SPACE_SUBSTITUTION:
         return weighting->getSpaceSubstitutionCost(traverseSession, dicNode);
     case CT_INSERTION:
@@ -146,6 +151,8 @@
                         traverseSession->getBinaryDictionaryInfo(), dicNode, multiBigramMap);
         return weighting->getTerminalLanguageCost(traverseSession, dicNode, languageImprobability);
     }
+    case CT_TERMINAL_INSERTION:
+        return 0.0f;
     case CT_NEW_WORD_SPACE_SUBSTITUTION:
         return weighting->getNewWordBigramLanguageCost(
                 traverseSession, parentDicNode, multiBigramMap);
@@ -163,9 +170,9 @@
         case CT_OMISSION:
             return 0;
         case CT_ADDITIONAL_PROXIMITY:
-            return 0;
+            return 0; /* 0 because CT_MATCH will be called */
         case CT_SUBSTITUTION:
-            return 0;
+            return 0; /* 0 because CT_MATCH will be called */
         case CT_NEW_WORD_SPACE_OMITTION:
             return 0;
         case CT_MATCH:
@@ -174,12 +181,14 @@
             return 1;
         case CT_TERMINAL:
             return 0;
+        case CT_TERMINAL_INSERTION:
+            return 1;
         case CT_NEW_WORD_SPACE_SUBSTITUTION:
             return 1;
         case CT_INSERTION:
-            return 2;
+            return 2; /* look ahead + skip the current char */
         case CT_TRANSPOSITION:
-            return 2;
+            return 2; /* look ahead + skip the current char */
         default:
             return 0;
     }
diff --git a/native/jni/src/suggest/core/policy/weighting.h b/native/jni/src/suggest/core/policy/weighting.h
index 781a7ad..2d49e98 100644
--- a/native/jni/src/suggest/core/policy/weighting.h
+++ b/native/jni/src/suggest/core/policy/weighting.h
@@ -67,6 +67,10 @@
             const DicTraverseSession *const traverseSession,
             const DicNode *const dicNode) const = 0;
 
+    virtual float getTerminalInsertionCost(
+            const DicTraverseSession *const traverseSession,
+            const DicNode *const dicNode) const = 0; // Cost used for CT_TERMINAL_INSERTION.
+
     virtual float getTerminalLanguageCost(
             const DicTraverseSession *const traverseSession, const DicNode *const dicNode,
             float dicNodeLanguageImprobability) const = 0;
diff --git a/native/jni/src/suggest/core/suggest.cpp b/native/jni/src/suggest/core/suggest.cpp
index d6383b9..73e9714 100644
--- a/native/jni/src/suggest/core/suggest.cpp
+++ b/native/jni/src/suggest/core/suggest.cpp
@@ -365,17 +365,17 @@
     if (!dicNode->isTerminalWordNode()) {
         return;
     }
-    if (TRAVERSAL->needsToTraverseAllUserInput()
-            && dicNode->getInputIndex(0) < traverseSession->getInputSize()) {
-        return;
-    }
-
     if (dicNode->shouldBeFilterdBySafetyNetForBigram()) {
         return;
     }
     // Create a non-cached node here.
     DicNode terminalDicNode;
     DicNodeUtils::initByCopy(dicNode, &terminalDicNode);
+    if (TRAVERSAL->needsToTraverseAllUserInput()
+            && dicNode->getInputIndex(0) < traverseSession->getInputSize()) {
+        Weighting::addCostAndForwardInputIndex(WEIGHTING, CT_TERMINAL_INSERTION, traverseSession, 0,
+                &terminalDicNode, traverseSession->getMultiBigramMap());
+    }
     Weighting::addCostAndForwardInputIndex(WEIGHTING, CT_TERMINAL, traverseSession, 0,
             &terminalDicNode, traverseSession->getMultiBigramMap());
     traverseSession->getDicTraverseCache()->copyPushTerminal(&terminalDicNode);
diff --git a/native/jni/src/suggest/policyimpl/typing/scoring_params.cpp b/native/jni/src/suggest/policyimpl/typing/scoring_params.cpp
index a8f797c..4157f41 100644
--- a/native/jni/src/suggest/policyimpl/typing/scoring_params.cpp
+++ b/native/jni/src/suggest/policyimpl/typing/scoring_params.cpp
@@ -34,6 +34,7 @@
 const float ScoringParams::OMISSION_COST_SAME_CHAR = 0.491f;
 const float ScoringParams::OMISSION_COST_FIRST_CHAR = 0.582f;
 const float ScoringParams::INSERTION_COST = 0.730f;
+const float ScoringParams::TERMINAL_INSERTION_COST = 0.93f;
 const float ScoringParams::INSERTION_COST_SAME_CHAR = 0.586f;
 const float ScoringParams::INSERTION_COST_PROXIMITY_CHAR = 0.70f;
 const float ScoringParams::INSERTION_COST_FIRST_CHAR = 0.623f;
diff --git a/native/jni/src/suggest/policyimpl/typing/scoring_params.h b/native/jni/src/suggest/policyimpl/typing/scoring_params.h
index 4ebcc7d..a743b4d 100644
--- a/native/jni/src/suggest/policyimpl/typing/scoring_params.h
+++ b/native/jni/src/suggest/policyimpl/typing/scoring_params.h
@@ -42,6 +42,7 @@
     static const float OMISSION_COST_SAME_CHAR;
     static const float OMISSION_COST_FIRST_CHAR;
     static const float INSERTION_COST;
+    static const float TERMINAL_INSERTION_COST;
     static const float INSERTION_COST_SAME_CHAR;
     static const float INSERTION_COST_PROXIMITY_CHAR;
     static const float INSERTION_COST_FIRST_CHAR;
diff --git a/native/jni/src/suggest/policyimpl/typing/typing_weighting.cpp b/native/jni/src/suggest/policyimpl/typing/typing_weighting.cpp
index e4c69d1..408b12a 100644
--- a/native/jni/src/suggest/policyimpl/typing/typing_weighting.cpp
+++ b/native/jni/src/suggest/policyimpl/typing/typing_weighting.cpp
@@ -44,6 +44,7 @@
             break;
         case CT_SUBSTITUTION:
         case CT_INSERTION:
+        case CT_TERMINAL_INSERTION:
         case CT_TRANSPOSITION:
             return ET_EDIT_CORRECTION;
         case CT_NEW_WORD_SPACE_OMITTION:
diff --git a/native/jni/src/suggest/policyimpl/typing/typing_weighting.h b/native/jni/src/suggest/policyimpl/typing/typing_weighting.h
index 1bb1607..7cddb08 100644
--- a/native/jni/src/suggest/policyimpl/typing/typing_weighting.h
+++ b/native/jni/src/suggest/policyimpl/typing/typing_weighting.h
@@ -175,6 +175,15 @@
         return dicNodeLanguageImprobability * ScoringParams::DISTANCE_WEIGHT_LANGUAGE;
     }
 
+    float getTerminalInsertionCost(const DicTraverseSession *const traverseSession,
+            const DicNode *const dicNode) const {
+        const int inputIndex = dicNode->getInputIndex(0); // Input index of pointer 0 on this node.
+        const int inputSize = traverseSession->getInputSize();
+        ASSERT(inputIndex < inputSize); // The caller only applies this cost when this holds.
+        // Linear in (inputSize - inputIndex). TODO: Implement more efficient logic
+        return  ScoringParams::TERMINAL_INSERTION_COST * (inputSize - inputIndex);
+    }
+
     AK_FORCE_INLINE bool needsToNormalizeCompoundDistance() const {
         return false;
     }