Move the input index and output index into CorrectionState

Change-Id: Idebdb59143f3367929df6a0475cefe941eb16d01
diff --git a/native/src/unigram_dictionary.cpp b/native/src/unigram_dictionary.cpp
index f5648d3..9f8f04e 100644
--- a/native/src/unigram_dictionary.cpp
+++ b/native/src/unigram_dictionary.cpp
@@ -363,27 +363,25 @@
     mStackSiblingPos[0] = rootPosition;
     mStackOutputIndex[0] = 0;
     mStackMatchedCount[0] = 0;
-    mCorrectionState->initDepth();
 
     // Depth first search
     while (depth >= 0) {
         if (mStackChildCount[depth] > 0) {
             --mStackChildCount[depth];
             bool traverseAllNodes = mStackTraverseAll[depth];
-            int inputIndex = mStackInputIndex[depth];
             int diffs = mStackDiffs[depth];
             int siblingPos = mStackSiblingPos[depth];
-            int outputIndex = mStackOutputIndex[depth];
             int firstChildPos;
-            mCorrectionState->slideTree(mStackMatchedCount[depth]);
+            mCorrectionState->initProcessState(
+                    mStackMatchedCount[depth], mStackInputIndex[depth], mStackOutputIndex[depth]);
 
             // depth will never be greater than maxDepth because in that case,
             // needsToTraverseChildrenNodes should be false
-            const bool needsToTraverseChildrenNodes = processCurrentNode(siblingPos, outputIndex,
-                    maxDepth, traverseAllNodes, inputIndex, diffs,
+            const bool needsToTraverseChildrenNodes = processCurrentNode(siblingPos,
+                    maxDepth, traverseAllNodes, diffs,
                     mCorrectionState, &childCount,
-                    &firstChildPos, &traverseAllNodes, &inputIndex, &diffs,
-                    &siblingPos, &outputIndex);
+                    &firstChildPos, &traverseAllNodes, &diffs,
+                    &siblingPos);
             // Update next sibling pos
             mStackSiblingPos[depth] = siblingPos;
             if (needsToTraverseChildrenNodes) {
@@ -391,21 +389,15 @@
                 ++depth;
                 mStackChildCount[depth] = childCount;
                 mStackTraverseAll[depth] = traverseAllNodes;
-                mStackInputIndex[depth] = inputIndex;
                 mStackDiffs[depth] = diffs;
                 mStackSiblingPos[depth] = firstChildPos;
-                mStackOutputIndex[depth] = outputIndex;
 
-                int matchedCount;
-                mCorrectionState->goDownTree(&matchedCount);
-                mStackMatchedCount[depth] = matchedCount;
-            } else {
-                mCorrectionState->slideTree(mStackMatchedCount[depth]);
+                mCorrectionState->getProcessState(&mStackMatchedCount[depth],
+                        &mStackInputIndex[depth], &mStackOutputIndex[depth]);
             }
         } else {
             // Goes to parent sibling node
             --depth;
-            mCorrectionState->goUpTree(mStackMatchedCount[depth]);
         }
     }
 }
@@ -446,13 +438,11 @@
 }
 
 
-inline void UnigramDictionary::onTerminal(unsigned short int* word, const int outputIndex,
-        const int inputIndex, const int freq, CorrectionState *correctionState) {
-    if (!mProximityInfo->sameAsTyped(word, outputIndex + 1) && outputIndex >= MIN_SUGGEST_DEPTH) {
-        const int finalFreq = correctionState->getFinalFreq(inputIndex, outputIndex, freq);
-        if (finalFreq >= 0) {
-            addWord(word, outputIndex + 1, finalFreq);
-        }
+inline void UnigramDictionary::onTerminal(
+        unsigned short int* word, const int freq, CorrectionState *correctionState) {
+    const int finalFreq = correctionState->getFinalFreq(word, freq);
+    if (finalFreq >= 0) {
+        addWord(word, correctionState->getOutputIndex() + 1, finalFreq);
     }
 }
 
@@ -667,12 +657,10 @@
 // there aren't any more nodes at this level, it merely returns the address of the first byte after
 // the current node in nextSiblingPosition. Thus, the caller must keep count of the nodes at any
 // given level, as output into newCount when traversing this level's parent.
-inline bool UnigramDictionary::processCurrentNode(const int initialPos, const int initialOutputPos,
-        const int maxDepth, const bool initialTraverseAllNodes, int inputIndex,
-        const int initialDiffs,
+inline bool UnigramDictionary::processCurrentNode(const int initialPos, const int maxDepth,
+        const bool initialTraverseAllNodes, const int initialDiffs,
         CorrectionState *correctionState, int *newCount, int *newChildrenPosition,
-        bool *newTraverseAllNodes, int *newInputIndex, int *newDiffs,
-        int *nextSiblingPosition, int *newOutputIndex) {
+        bool *newTraverseAllNodes, int *newDiffs, int *nextSiblingPosition) {
     const int skipPos = correctionState->getSkipPos();
     const int excessivePos = correctionState->getExcessivePos();
     const int transposedPos = correctionState->getTransposedPos();
@@ -680,9 +668,9 @@
         correctionState->checkState();
     }
     int pos = initialPos;
-    int internalOutputPos = initialOutputPos;
     int traverseAllNodes = initialTraverseAllNodes;
     int diffs = initialDiffs;
+    const int initialInputIndex = correctionState->getInputIndex();
 
     // Flags contain the following information:
     // - Address type (MASK_GROUP_ADDRESS_TYPE) on two bits:
@@ -726,16 +714,18 @@
 
         // This has to be done for each virtual char (this forwards the "inputIndex" which
         // is the index in the user-inputted chars, as read by proximity chars.
-        if (excessivePos == internalOutputPos && inputIndex < mInputLength - 1) {
-            ++inputIndex;
+        if (excessivePos == correctionState->getOutputIndex()
+                && correctionState->getInputIndex() < mInputLength - 1) {
+            correctionState->incrementInputIndex();
         }
-        if (traverseAllNodes || needsToSkipCurrentNode(c, inputIndex, skipPos, internalOutputPos)) {
-            mWord[internalOutputPos] = c;
+        if (traverseAllNodes || needsToSkipCurrentNode(
+                c, correctionState->getInputIndex(), skipPos, correctionState->getOutputIndex())) {
+            mWord[correctionState->getOutputIndex()] = c;
             if (traverseAllNodes && isTerminal) {
                 // The frequency should be here, because we come here only if this is actually
                 // a terminal node, and we are on its last char.
                 const int freq = BinaryFormat::readFrequencyWithoutMovingPointer(DICT_ROOT, pos);
-                onTerminal(mWord, internalOutputPos, inputIndex, freq, mCorrectionState);
+                onTerminal(mWord, freq, mCorrectionState);
             }
             if (!hasChildren) {
                 // If we don't have children here, that means we finished processing all
@@ -750,11 +740,15 @@
                 return false;
             }
         } else {
-            int inputIndexForProximity = inputIndex;
+            int inputIndexForProximity = correctionState->getInputIndex();
 
             if (transposedPos >= 0) {
-                if (inputIndex == transposedPos) ++inputIndexForProximity;
-                if (inputIndex == (transposedPos + 1)) --inputIndexForProximity;
+                if (correctionState->getInputIndex() == transposedPos) {
+                    ++inputIndexForProximity;
+                }
+                if (correctionState->getInputIndex() == (transposedPos + 1)) {
+                    --inputIndexForProximity;
+                }
             }
 
             int matchedProximityCharId = mProximityInfo->getMatchedProximityId(
@@ -775,18 +769,31 @@
                         BinaryFormat::skipChildrenPosAndAttributes(DICT_ROOT, flags, pos);
                 return false;
             }
-            mWord[internalOutputPos] = c;
+            mWord[correctionState->getOutputIndex()] = c;
             // If inputIndex is greater than mInputLength, that means there is no
             // proximity chars. So, we don't need to check proximity.
             if (ProximityInfo::SAME_OR_ACCENTED_OR_CAPITALIZED_CHAR == matchedProximityCharId) {
                 correctionState->charMatched();
             }
-            const bool isSameAsUserTypedLength = mInputLength == inputIndex + 1
-                    || (excessivePos == mInputLength - 1 && inputIndex == mInputLength - 2);
+            const bool isSameAsUserTypedLength = mInputLength
+                    == correctionState->getInputIndex() + 1
+                            || (excessivePos == mInputLength - 1
+                                        && correctionState->getInputIndex() == mInputLength - 2);
             if (isSameAsUserTypedLength && isTerminal) {
                 const int freq = BinaryFormat::readFrequencyWithoutMovingPointer(DICT_ROOT, pos);
-                onTerminal(mWord, internalOutputPos, inputIndex, freq, mCorrectionState);
+                onTerminal(mWord, freq, mCorrectionState);
             }
+            // Start traversing all nodes after the index exceeds the user typed length
+            traverseAllNodes = isSameAsUserTypedLength;
+            diffs = diffs
+                    + ((ProximityInfo::NEAR_PROXIMITY_CHAR == matchedProximityCharId) ? 1 : 0);
+            // Finally, we are ready to go to the next character, the next "virtual node".
+            // We should advance the input index.
+            // We do this in this branch of the 'if traverseAllNodes' because we are still matching
+            // characters to input; the other branch is not matching them but searching for
+            // completions, this is why it does not have to do it.
+            correctionState->incrementInputIndex();
+
             // This character matched the typed character (enough to traverse the node at least)
             // so we just evaluated it. Now we should evaluate this virtual node's children - that
             // is, if it has any. If it has no children, we're done here - so we skip the end of
@@ -799,19 +806,9 @@
                         BinaryFormat::skipChildrenPosAndAttributes(DICT_ROOT, flags, pos);
                 return false;
             }
-            // Start traversing all nodes after the index exceeds the user typed length
-            traverseAllNodes = isSameAsUserTypedLength;
-            diffs = diffs
-                    + ((ProximityInfo::NEAR_PROXIMITY_CHAR == matchedProximityCharId) ? 1 : 0);
-            // Finally, we are ready to go to the next character, the next "virtual node".
-            // We should advance the input index.
-            // We do this in this branch of the 'if traverseAllNodes' because we are still matching
-            // characters to input; the other branch is not matching them but searching for
-            // completions, this is why it does not have to do it.
-            ++inputIndex;
         }
         // Optimization: Prune out words that are too long compared to how much was typed.
-        if (internalOutputPos >= maxDepth || diffs > mMaxEditDistance) {
+        if (correctionState->getOutputIndex() >= maxDepth || diffs > mMaxEditDistance) {
             // We are giving up parsing this node and its children. Skip the rest of the node,
             // output the sibling position, and return that we don't want to traverse children.
             if (!isLastChar) {
@@ -822,18 +819,18 @@
                     BinaryFormat::skipChildrenPosAndAttributes(DICT_ROOT, flags, pos);
             return false;
         }
+        // Also, the next char is one "virtual node" depth more than this char.
+        correctionState->incrementOutputIndex();
 
         // Prepare for the next character. Promote the prefetched char to current char - the loop
         // will take care of prefetching the next. If we finally found our last char, nextc will
         // contain NOT_A_CHARACTER.
         c = nextc;
-        // Also, the next char is one "virtual node" depth more than this char.
-        ++internalOutputPos;
     } while (NOT_A_CHARACTER != c);
 
     // If inputIndex is greater than mInputLength, that means there are no proximity chars.
     // Here, that's all we are interested in so we don't need to check for isSameAsUserTypedLength.
-    if (mInputLength <= *newInputIndex) {
+    if (mInputLength <= initialInputIndex) {
         traverseAllNodes = true;
     }
 
@@ -841,8 +838,6 @@
     // variables. Output them to the caller.
     *newTraverseAllNodes = traverseAllNodes;
     *newDiffs = diffs;
-    *newInputIndex = inputIndex;
-    *newOutputIndex = internalOutputPos;
 
     // Now we finished processing this node, and we want to traverse children. If there are no
     // children, we can't come here.