Move code related to ranking algorithm to correction_state.cpp

Change-Id: I52b34de45969fef82e46d9c10079c2d45e0b94eb
diff --git a/native/src/unigram_dictionary.cpp b/native/src/unigram_dictionary.cpp
index eb28538..f5648d3 100644
--- a/native/src/unigram_dictionary.cpp
+++ b/native/src/unigram_dictionary.cpp
@@ -167,12 +167,6 @@
             LOGI("%s %i", s, mFrequencies[j]);
 #endif
         }
-        LOGI("Next letters: ");
-        for (int k = 0; k < NEXT_LETTERS_SIZE; k++) {
-            if (mNextLettersFrequency[k] > 0) {
-                LOGI("%c = %d,", k, mNextLettersFrequency[k]);
-            }
-        }
     }
     PROF_END(20);
     PROF_CLOSE;
@@ -194,7 +188,7 @@
     PROF_END(0);
 
     PROF_START(1);
-    getSuggestionCandidates(-1, -1, -1, mNextLettersFrequency, NEXT_LETTERS_SIZE, MAX_DEPTH);
+    getSuggestionCandidates(-1, -1, -1, MAX_DEPTH);
     PROF_END(1);
 
     PROF_START(2);
@@ -204,7 +198,7 @@
             if (DEBUG_DICT) {
                 LOGI("--- Suggest missing characters %d", i);
             }
-            getSuggestionCandidates(i, -1, -1, NULL, 0, MAX_DEPTH);
+            getSuggestionCandidates(i, -1, -1, MAX_DEPTH);
         }
     }
     PROF_END(2);
@@ -217,7 +211,7 @@
             if (DEBUG_DICT) {
                 LOGI("--- Suggest excessive characters %d", i);
             }
-            getSuggestionCandidates(-1, i, -1, NULL, 0, MAX_DEPTH);
+            getSuggestionCandidates(-1, i, -1, MAX_DEPTH);
         }
     }
     PROF_END(3);
@@ -230,7 +224,7 @@
             if (DEBUG_DICT) {
                 LOGI("--- Suggest transposed characters %d", i);
             }
-            getSuggestionCandidates(-1, -1, i, NULL, 0, mInputLength - 1);
+            getSuggestionCandidates(-1, -1, i, mInputLength - 1);
         }
     }
     PROF_END(4);
@@ -348,8 +342,7 @@
 static const char SPACE = ' ';
 
 void UnigramDictionary::getSuggestionCandidates(const int skipPos,
-        const int excessivePos, const int transposedPos, int *nextLetters,
-        const int nextLettersSize, const int maxDepth) {
+        const int excessivePos, const int transposedPos, const int maxDepth) {
     if (DEBUG_DICT) {
         LOGI("getSuggestionCandidates %d", maxDepth);
         assert(transposedPos + 1 < mInputLength);
@@ -365,29 +358,31 @@
 
     mStackChildCount[0] = childCount;
     mStackTraverseAll[0] = (mInputLength <= 0);
-    mStackMatchCount[0] = 0;
     mStackInputIndex[0] = 0;
     mStackDiffs[0] = 0;
     mStackSiblingPos[0] = rootPosition;
     mStackOutputIndex[0] = 0;
+    mStackMatchedCount[0] = 0;
+    mCorrectionState->initDepth();
 
     // Depth first search
     while (depth >= 0) {
         if (mStackChildCount[depth] > 0) {
             --mStackChildCount[depth];
             bool traverseAllNodes = mStackTraverseAll[depth];
-            int matchCount = mStackMatchCount[depth];
             int inputIndex = mStackInputIndex[depth];
             int diffs = mStackDiffs[depth];
             int siblingPos = mStackSiblingPos[depth];
             int outputIndex = mStackOutputIndex[depth];
             int firstChildPos;
+            mCorrectionState->slideTree(mStackMatchedCount[depth]);
+
             // depth will never be greater than maxDepth because in that case,
             // needsToTraverseChildrenNodes should be false
             const bool needsToTraverseChildrenNodes = processCurrentNode(siblingPos, outputIndex,
-                    maxDepth, traverseAllNodes, matchCount, inputIndex, diffs,
-                    nextLetters, nextLettersSize, mCorrectionState, &childCount,
-                    &firstChildPos, &traverseAllNodes, &matchCount, &inputIndex, &diffs,
+                    maxDepth, traverseAllNodes, inputIndex, diffs,
+                    mCorrectionState, &childCount,
+                    &firstChildPos, &traverseAllNodes, &inputIndex, &diffs,
                     &siblingPos, &outputIndex);
             // Update next sibling pos
             mStackSiblingPos[depth] = siblingPos;
@@ -396,15 +391,21 @@
                 ++depth;
                 mStackChildCount[depth] = childCount;
                 mStackTraverseAll[depth] = traverseAllNodes;
-                mStackMatchCount[depth] = matchCount;
                 mStackInputIndex[depth] = inputIndex;
                 mStackDiffs[depth] = diffs;
                 mStackSiblingPos[depth] = firstChildPos;
                 mStackOutputIndex[depth] = outputIndex;
+
+                int matchedCount;
+                mCorrectionState->goDownTree(&matchedCount);
+                mStackMatchedCount[depth] = matchedCount;
+            } else {
+                mCorrectionState->slideTree(mStackMatchedCount[depth]);
             }
         } else {
             // Goes to parent sibling node
             --depth;
+            mCorrectionState->goUpTree(mStackMatchedCount[depth]);
         }
     }
 }
@@ -445,24 +446,13 @@
 }
 
 
-inline void UnigramDictionary::onTerminal(unsigned short int* word, const int depth,
-        const uint8_t* const root, const uint8_t flags, const int pos,
-        const int inputIndex, const int matchCount, const int freq, const bool sameLength,
-        int* nextLetters, const int nextLettersSize, CorrectionState *correctionState) {
-    const int skipPos = correctionState->getSkipPos();
-
-    const bool isSameAsTyped = sameLength ? mProximityInfo->sameAsTyped(word, depth + 1) : false;
-    if (isSameAsTyped) return;
-
-    if (depth >= MIN_SUGGEST_DEPTH) {
-        const int finalFreq = correctionState->getFinalFreq(inputIndex, depth, matchCount,
-                freq, sameLength);
-        if (!isSameAsTyped)
-            addWord(word, depth + 1, finalFreq);
-    }
-
-    if (sameLength && depth >= mInputLength && skipPos < 0) {
-        registerNextLetter(word[mInputLength], nextLetters, nextLettersSize);
+inline void UnigramDictionary::onTerminal(unsigned short int* word, const int outputIndex,
+        const int inputIndex, const int freq, CorrectionState *correctionState) {
+    if (!mProximityInfo->sameAsTyped(word, outputIndex + 1) && outputIndex >= MIN_SUGGEST_DEPTH) {
+        const int finalFreq = correctionState->getFinalFreq(inputIndex, outputIndex, freq);
+        if (finalFreq >= 0) {
+            addWord(word, outputIndex + 1, finalFreq);
+        }
     }
 }
 
@@ -677,11 +667,11 @@
 // there aren't any more nodes at this level, it merely returns the address of the first byte after
 // the current node in nextSiblingPosition. Thus, the caller must keep count of the nodes at any
 // given level, as output into newCount when traversing this level's parent.
-inline bool UnigramDictionary::processCurrentNode(const int initialPos, const int initialDepth,
-        const int maxDepth, const bool initialTraverseAllNodes, int matchCount, int inputIndex,
-        const int initialDiffs, int *nextLetters, const int nextLettersSize,
+inline bool UnigramDictionary::processCurrentNode(const int initialPos, const int initialOutputPos,
+        const int maxDepth, const bool initialTraverseAllNodes, int inputIndex,
+        const int initialDiffs,
         CorrectionState *correctionState, int *newCount, int *newChildrenPosition,
-        bool *newTraverseAllNodes, int *newMatchRate, int *newInputIndex, int *newDiffs,
+        bool *newTraverseAllNodes, int *newInputIndex, int *newDiffs,
         int *nextSiblingPosition, int *newOutputIndex) {
     const int skipPos = correctionState->getSkipPos();
     const int excessivePos = correctionState->getExcessivePos();
@@ -690,7 +680,7 @@
         correctionState->checkState();
     }
     int pos = initialPos;
-    int depth = initialDepth;
+    int internalOutputPos = initialOutputPos;
     int traverseAllNodes = initialTraverseAllNodes;
     int diffs = initialDiffs;
 
@@ -736,15 +726,16 @@
 
         // This has to be done for each virtual char (this forwards the "inputIndex" which
         // is the index in the user-inputted chars, as read by proximity chars.
-        if (excessivePos == depth && inputIndex < mInputLength - 1) ++inputIndex;
-        if (traverseAllNodes || needsToSkipCurrentNode(c, inputIndex, skipPos, depth)) {
-            mWord[depth] = c;
+        if (excessivePos == internalOutputPos && inputIndex < mInputLength - 1) {
+            ++inputIndex;
+        }
+        if (traverseAllNodes || needsToSkipCurrentNode(c, inputIndex, skipPos, internalOutputPos)) {
+            mWord[internalOutputPos] = c;
             if (traverseAllNodes && isTerminal) {
                 // The frequency should be here, because we come here only if this is actually
                 // a terminal node, and we are on its last char.
                 const int freq = BinaryFormat::readFrequencyWithoutMovingPointer(DICT_ROOT, pos);
-                onTerminal(mWord, depth, DICT_ROOT, flags, pos, inputIndex, matchCount,
-                        freq, false, nextLetters, nextLettersSize, mCorrectionState);
+                onTerminal(mWord, internalOutputPos, inputIndex, freq, mCorrectionState);
             }
             if (!hasChildren) {
                 // If we don't have children here, that means we finished processing all
@@ -784,18 +775,17 @@
                         BinaryFormat::skipChildrenPosAndAttributes(DICT_ROOT, flags, pos);
                 return false;
             }
-            mWord[depth] = c;
+            mWord[internalOutputPos] = c;
             // If inputIndex is greater than mInputLength, that means there is no
             // proximity chars. So, we don't need to check proximity.
             if (ProximityInfo::SAME_OR_ACCENTED_OR_CAPITALIZED_CHAR == matchedProximityCharId) {
-                ++matchCount;
+                correctionState->charMatched();
             }
             const bool isSameAsUserTypedLength = mInputLength == inputIndex + 1
                     || (excessivePos == mInputLength - 1 && inputIndex == mInputLength - 2);
             if (isSameAsUserTypedLength && isTerminal) {
                 const int freq = BinaryFormat::readFrequencyWithoutMovingPointer(DICT_ROOT, pos);
-                onTerminal(mWord, depth, DICT_ROOT, flags, pos, inputIndex, matchCount,
-                        freq, true, nextLetters, nextLettersSize, mCorrectionState);
+                onTerminal(mWord, internalOutputPos, inputIndex, freq, mCorrectionState);
             }
             // This character matched the typed character (enough to traverse the node at least)
             // so we just evaluated it. Now we should evaluate this virtual node's children - that
@@ -821,7 +811,7 @@
             ++inputIndex;
         }
         // Optimization: Prune out words that are too long compared to how much was typed.
-        if (depth >= maxDepth || diffs > mMaxEditDistance) {
+        if (internalOutputPos >= maxDepth || diffs > mMaxEditDistance) {
             // We are giving up parsing this node and its children. Skip the rest of the node,
             // output the sibling position, and return that we don't want to traverse children.
             if (!isLastChar) {
@@ -838,7 +828,7 @@
         // contain NOT_A_CHARACTER.
         c = nextc;
         // Also, the next char is one "virtual node" depth more than this char.
-        ++depth;
+        ++internalOutputPos;
     } while (NOT_A_CHARACTER != c);
 
     // If inputIndex is greater than mInputLength, that means there are no proximity chars.
@@ -850,10 +840,9 @@
     // All the output values that are purely computation by this function are held in local
     // variables. Output them to the caller.
     *newTraverseAllNodes = traverseAllNodes;
-    *newMatchRate = matchCount;
     *newDiffs = diffs;
     *newInputIndex = inputIndex;
-    *newOutputIndex = depth;
+    *newOutputIndex = internalOutputPos;
 
     // Now we finished processing this node, and we want to traverse children. If there are no
     // children, we can't come here.