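Move the input and output indices into CorrectionState

UnigramDictionary used to thread inputIndex and outputIndex through
processCurrentNode() and onTerminal() as arguments and out-parameters.
CorrectionState now stores both indices alongside the matched-char
count, so the traversal saves and restores all three together:

- initDepth(), goUpTree(), slideTree() and goDownTree() are replaced
  by initProcessState(), which loads the state from the DFS stack, and
  getProcessState(), which stores it for the child depth.
- incrementInputIndex() and incrementOutputIndex() advance the
  indices; getInputIndex() and getOutputIndex() are temporary
  accessors marked "TODO: remove".
- getFinalFreq() now takes the word instead of the indices and applies
  the sameAsTyped / MIN_SUGGEST_DEPTH check that onTerminal() used to
  do, returning -1 for rejected words.
- A CorrectionStateType enum is added; this change does not use it yet.
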
Change-Id: Idebdb59143f3367929df6a0475cefe941eb16d01
diff --git a/native/src/correction_state.cpp b/native/src/correction_state.cpp
index add6cf6..b2c77b0 100644
--- a/native/src/correction_state.cpp
+++ b/native/src/correction_state.cpp
@@ -58,32 +58,49 @@
     return CorrectionState::RankingAlgorithm::calcFreqForSplitTwoWords(firstFreq, secondFreq, this);
 }
 
-int CorrectionState::getFinalFreq(const int inputIndex, const int outputIndex, const int freq) {
-    const bool sameLength = (mExcessivePos == mInputLength - 1) ? (mInputLength == inputIndex + 2)
-            : (mInputLength == inputIndex + 1);
-    const int matchCount = mMatchedCharCount;
+int CorrectionState::getFinalFreq(const unsigned short *word, const int freq) {
+    if (mProximityInfo->sameAsTyped(word, mOutputIndex + 1) || mOutputIndex < MIN_SUGGEST_DEPTH) {
+        return -1;
+    }
+    const bool sameLength = (mExcessivePos == mInputLength - 1) ? (mInputLength == mInputIndex + 2)
+            : (mInputLength == mInputIndex + 1);
     return CorrectionState::RankingAlgorithm::calculateFinalFreq(
-            inputIndex, outputIndex, matchCount, freq, sameLength, this);
+            mInputIndex, mOutputIndex, mMatchedCharCount, freq, sameLength, this);
 }
 
-void CorrectionState::initDepth() {
-    mMatchedCharCount = 0;
+void CorrectionState::initProcessState(
+        const int matchCount, const int inputIndex, const int outputIndex) {
+    mMatchedCharCount = matchCount;
+    mInputIndex = inputIndex;
+    mOutputIndex = outputIndex;
+}
+
+void CorrectionState::getProcessState(int *matchedCount, int *inputIndex, int *outputIndex) {
+    *matchedCount = mMatchedCharCount;
+    *inputIndex = mInputIndex;
+    *outputIndex = mOutputIndex;
 }
 
 void CorrectionState::charMatched() {
     ++mMatchedCharCount;
 }
 
-void CorrectionState::goUpTree(const int matchCount) {
-    mMatchedCharCount = matchCount;
+// TODO: remove
+int CorrectionState::getOutputIndex() {
+    return mOutputIndex;
 }
 
-void CorrectionState::slideTree(const int matchCount) {
-    mMatchedCharCount = matchCount;
+// TODO: remove
+int CorrectionState::getInputIndex() {
+    return mInputIndex;
 }
 
-void CorrectionState::goDownTree(int *matchedCount) {
-    *matchedCount = mMatchedCharCount;
+void CorrectionState::incrementInputIndex() {
+    ++mInputIndex;
+}
+
+void CorrectionState::incrementOutputIndex() {
+    ++mOutputIndex;
 }
 
 CorrectionState::~CorrectionState() {
diff --git a/native/src/correction_state.h b/native/src/correction_state.h
index 7bbad5f..cc3c3e6 100644
--- a/native/src/correction_state.h
+++ b/native/src/correction_state.h
@@ -28,16 +28,25 @@
 class CorrectionState {
 
 public:
+    typedef enum {
+        ALLOW_ALL,
+        UNRELATED,
+        RELATED
+    } CorrectionStateType;
+
     CorrectionState(const int typedLetterMultiplier, const int fullWordMultiplier);
     void initCorrectionState(const ProximityInfo *pi, const int inputLength);
     void setCorrectionParams(const int skipPos, const int excessivePos, const int transposedPos,
             const int spaceProximityPos, const int missingSpacePos);
-    void initDepth();
     void checkState();
-    void goUpTree(const int matchCount);
-    void slideTree(const int matchCount);
-    void goDownTree(int *matchedCount);
+    void initProcessState(const int matchCount, const int inputIndex, const int outputIndex);
+    void getProcessState(int *matchedCount, int *inputIndex, int *outputIndex);
     void charMatched();
+    void incrementInputIndex();
+    void incrementOutputIndex();
+    int getOutputIndex();
+    int getInputIndex();
+
     virtual ~CorrectionState();
     int getSkipPos() const {
         return mSkipPos;
@@ -55,7 +64,7 @@
         return mMissingSpacePos;
     }
     int getFreqForSplitTwoWords(const int firstFreq, const int secondFreq);
-    int getFinalFreq(const int inputIndex, const int outputIndex, const int freq);
+    int getFinalFreq(const unsigned short *word, const int freq);
 
 private:
 
@@ -71,6 +80,8 @@
     int mMissingSpacePos;
 
     int mMatchedCharCount;
+    int mInputIndex;
+    int mOutputIndex;
 
     class RankingAlgorithm {
     public:
diff --git a/native/src/unigram_dictionary.cpp b/native/src/unigram_dictionary.cpp
index f5648d3..9f8f04e 100644
--- a/native/src/unigram_dictionary.cpp
+++ b/native/src/unigram_dictionary.cpp
@@ -363,27 +363,25 @@
     mStackSiblingPos[0] = rootPosition;
     mStackOutputIndex[0] = 0;
     mStackMatchedCount[0] = 0;
-    mCorrectionState->initDepth();
 
     // Depth first search
     while (depth >= 0) {
         if (mStackChildCount[depth] > 0) {
             --mStackChildCount[depth];
             bool traverseAllNodes = mStackTraverseAll[depth];
-            int inputIndex = mStackInputIndex[depth];
             int diffs = mStackDiffs[depth];
             int siblingPos = mStackSiblingPos[depth];
-            int outputIndex = mStackOutputIndex[depth];
             int firstChildPos;
-            mCorrectionState->slideTree(mStackMatchedCount[depth]);
+            mCorrectionState->initProcessState(
+                    mStackMatchedCount[depth], mStackInputIndex[depth], mStackOutputIndex[depth]);
 
             // depth will never be greater than maxDepth because in that case,
             // needsToTraverseChildrenNodes should be false
-            const bool needsToTraverseChildrenNodes = processCurrentNode(siblingPos, outputIndex,
-                    maxDepth, traverseAllNodes, inputIndex, diffs,
+            const bool needsToTraverseChildrenNodes = processCurrentNode(siblingPos,
+                    maxDepth, traverseAllNodes, diffs,
                     mCorrectionState, &childCount,
-                    &firstChildPos, &traverseAllNodes, &inputIndex, &diffs,
-                    &siblingPos, &outputIndex);
+                    &firstChildPos, &traverseAllNodes, &diffs,
+                    &siblingPos);
             // Update next sibling pos
             mStackSiblingPos[depth] = siblingPos;
             if (needsToTraverseChildrenNodes) {
@@ -391,21 +389,15 @@
                 ++depth;
                 mStackChildCount[depth] = childCount;
                 mStackTraverseAll[depth] = traverseAllNodes;
-                mStackInputIndex[depth] = inputIndex;
                 mStackDiffs[depth] = diffs;
                 mStackSiblingPos[depth] = firstChildPos;
-                mStackOutputIndex[depth] = outputIndex;
 
-                int matchedCount;
-                mCorrectionState->goDownTree(&matchedCount);
-                mStackMatchedCount[depth] = matchedCount;
-            } else {
-                mCorrectionState->slideTree(mStackMatchedCount[depth]);
+                mCorrectionState->getProcessState(&mStackMatchedCount[depth],
+                        &mStackInputIndex[depth], &mStackOutputIndex[depth]);
             }
         } else {
             // Goes to parent sibling node
             --depth;
-            mCorrectionState->goUpTree(mStackMatchedCount[depth]);
         }
     }
 }
@@ -446,13 +438,11 @@
 }
 
 
-inline void UnigramDictionary::onTerminal(unsigned short int* word, const int outputIndex,
-        const int inputIndex, const int freq, CorrectionState *correctionState) {
-    if (!mProximityInfo->sameAsTyped(word, outputIndex + 1) && outputIndex >= MIN_SUGGEST_DEPTH) {
-        const int finalFreq = correctionState->getFinalFreq(inputIndex, outputIndex, freq);
-        if (finalFreq >= 0) {
-            addWord(word, outputIndex + 1, finalFreq);
-        }
+inline void UnigramDictionary::onTerminal(
+        unsigned short int* word, const int freq, CorrectionState *correctionState) {
+    const int finalFreq = correctionState->getFinalFreq(word, freq);
+    if (finalFreq >= 0) {
+        addWord(word, correctionState->getOutputIndex() + 1, finalFreq);
     }
 }
 
@@ -667,12 +657,10 @@
 // there aren't any more nodes at this level, it merely returns the address of the first byte after
 // the current node in nextSiblingPosition. Thus, the caller must keep count of the nodes at any
 // given level, as output into newCount when traversing this level's parent.
-inline bool UnigramDictionary::processCurrentNode(const int initialPos, const int initialOutputPos,
-        const int maxDepth, const bool initialTraverseAllNodes, int inputIndex,
-        const int initialDiffs,
+inline bool UnigramDictionary::processCurrentNode(const int initialPos, const int maxDepth,
+        const bool initialTraverseAllNodes, const int initialDiffs,
         CorrectionState *correctionState, int *newCount, int *newChildrenPosition,
-        bool *newTraverseAllNodes, int *newInputIndex, int *newDiffs,
-        int *nextSiblingPosition, int *newOutputIndex) {
+        bool *newTraverseAllNodes, int *newDiffs, int *nextSiblingPosition) {
     const int skipPos = correctionState->getSkipPos();
     const int excessivePos = correctionState->getExcessivePos();
     const int transposedPos = correctionState->getTransposedPos();
@@ -680,9 +668,9 @@
         correctionState->checkState();
     }
     int pos = initialPos;
-    int internalOutputPos = initialOutputPos;
     int traverseAllNodes = initialTraverseAllNodes;
     int diffs = initialDiffs;
+    const int initialInputIndex = correctionState->getInputIndex();
 
     // Flags contain the following information:
     // - Address type (MASK_GROUP_ADDRESS_TYPE) on two bits:
@@ -726,16 +714,18 @@
 
         // This has to be done for each virtual char (this forwards the "inputIndex" which
         // is the index in the user-inputted chars, as read by proximity chars.
-        if (excessivePos == internalOutputPos && inputIndex < mInputLength - 1) {
-            ++inputIndex;
+        if (excessivePos == correctionState->getOutputIndex()
+                && correctionState->getInputIndex() < mInputLength - 1) {
+            correctionState->incrementInputIndex();
         }
-        if (traverseAllNodes || needsToSkipCurrentNode(c, inputIndex, skipPos, internalOutputPos)) {
-            mWord[internalOutputPos] = c;
+        if (traverseAllNodes || needsToSkipCurrentNode(
+                c, correctionState->getInputIndex(), skipPos, correctionState->getOutputIndex())) {
+            mWord[correctionState->getOutputIndex()] = c;
             if (traverseAllNodes && isTerminal) {
                 // The frequency should be here, because we come here only if this is actually
                 // a terminal node, and we are on its last char.
                 const int freq = BinaryFormat::readFrequencyWithoutMovingPointer(DICT_ROOT, pos);
-                onTerminal(mWord, internalOutputPos, inputIndex, freq, mCorrectionState);
+                onTerminal(mWord, freq, mCorrectionState);
             }
             if (!hasChildren) {
                 // If we don't have children here, that means we finished processing all
@@ -750,11 +740,15 @@
                 return false;
             }
         } else {
-            int inputIndexForProximity = inputIndex;
+            int inputIndexForProximity = correctionState->getInputIndex();
 
             if (transposedPos >= 0) {
-                if (inputIndex == transposedPos) ++inputIndexForProximity;
-                if (inputIndex == (transposedPos + 1)) --inputIndexForProximity;
+                if (correctionState->getInputIndex() == transposedPos) {
+                    ++inputIndexForProximity;
+                }
+                if (correctionState->getInputIndex() == (transposedPos + 1)) {
+                    --inputIndexForProximity;
+                }
             }
 
             int matchedProximityCharId = mProximityInfo->getMatchedProximityId(
@@ -775,18 +769,31 @@
                         BinaryFormat::skipChildrenPosAndAttributes(DICT_ROOT, flags, pos);
                 return false;
             }
-            mWord[internalOutputPos] = c;
+            mWord[correctionState->getOutputIndex()] = c;
             // If inputIndex is greater than mInputLength, that means there is no
             // proximity chars. So, we don't need to check proximity.
             if (ProximityInfo::SAME_OR_ACCENTED_OR_CAPITALIZED_CHAR == matchedProximityCharId) {
                 correctionState->charMatched();
             }
-            const bool isSameAsUserTypedLength = mInputLength == inputIndex + 1
-                    || (excessivePos == mInputLength - 1 && inputIndex == mInputLength - 2);
+            const bool isSameAsUserTypedLength = mInputLength
+                    == correctionState->getInputIndex() + 1
+                            || (excessivePos == mInputLength - 1
+                                        && correctionState->getInputIndex() == mInputLength - 2);
             if (isSameAsUserTypedLength && isTerminal) {
                 const int freq = BinaryFormat::readFrequencyWithoutMovingPointer(DICT_ROOT, pos);
-                onTerminal(mWord, internalOutputPos, inputIndex, freq, mCorrectionState);
+                onTerminal(mWord, freq, mCorrectionState);
             }
+            // Start traversing all nodes after the index exceeds the user typed length
+            traverseAllNodes = isSameAsUserTypedLength;
+            diffs = diffs
+                    + ((ProximityInfo::NEAR_PROXIMITY_CHAR == matchedProximityCharId) ? 1 : 0);
+            // Finally, we are ready to go to the next character, the next "virtual node".
+            // We should advance the input index.
+            // We do this in this branch of the 'if traverseAllNodes' because we are still matching
+            // characters to input; the other branch is not matching them but searching for
+            // completions, this is why it does not have to do it.
+            correctionState->incrementInputIndex();
+
             // This character matched the typed character (enough to traverse the node at least)
             // so we just evaluated it. Now we should evaluate this virtual node's children - that
             // is, if it has any. If it has no children, we're done here - so we skip the end of
@@ -799,19 +806,9 @@
                         BinaryFormat::skipChildrenPosAndAttributes(DICT_ROOT, flags, pos);
                 return false;
             }
-            // Start traversing all nodes after the index exceeds the user typed length
-            traverseAllNodes = isSameAsUserTypedLength;
-            diffs = diffs
-                    + ((ProximityInfo::NEAR_PROXIMITY_CHAR == matchedProximityCharId) ? 1 : 0);
-            // Finally, we are ready to go to the next character, the next "virtual node".
-            // We should advance the input index.
-            // We do this in this branch of the 'if traverseAllNodes' because we are still matching
-            // characters to input; the other branch is not matching them but searching for
-            // completions, this is why it does not have to do it.
-            ++inputIndex;
         }
         // Optimization: Prune out words that are too long compared to how much was typed.
-        if (internalOutputPos >= maxDepth || diffs > mMaxEditDistance) {
+        if (correctionState->getOutputIndex() >= maxDepth || diffs > mMaxEditDistance) {
             // We are giving up parsing this node and its children. Skip the rest of the node,
             // output the sibling position, and return that we don't want to traverse children.
             if (!isLastChar) {
@@ -822,18 +819,18 @@
                     BinaryFormat::skipChildrenPosAndAttributes(DICT_ROOT, flags, pos);
             return false;
         }
+        // Also, the next char is one "virtual node" depth more than this char.
+        correctionState->incrementOutputIndex();
 
         // Prepare for the next character. Promote the prefetched char to current char - the loop
         // will take care of prefetching the next. If we finally found our last char, nextc will
         // contain NOT_A_CHARACTER.
         c = nextc;
-        // Also, the next char is one "virtual node" depth more than this char.
-        ++internalOutputPos;
     } while (NOT_A_CHARACTER != c);
 
     // If inputIndex is greater than mInputLength, that means there are no proximity chars.
     // Here, that's all we are interested in so we don't need to check for isSameAsUserTypedLength.
-    if (mInputLength <= *newInputIndex) {
+    if (mInputLength <= initialInputIndex) {
         traverseAllNodes = true;
     }
 
@@ -841,8 +838,6 @@
     // variables. Output them to the caller.
     *newTraverseAllNodes = traverseAllNodes;
     *newDiffs = diffs;
-    *newInputIndex = inputIndex;
-    *newOutputIndex = internalOutputPos;
 
     // Now we finished processing this node, and we want to traverse children. If there are no
     // children, we can't come here.
diff --git a/native/src/unigram_dictionary.h b/native/src/unigram_dictionary.h
index c67eaf6..cb86da4 100644
--- a/native/src/unigram_dictionary.h
+++ b/native/src/unigram_dictionary.h
@@ -94,18 +94,14 @@
             const int inputLength, const int missingSpacePos, CorrectionState *correctionState);
     void getMistypedSpaceWords(
             const int inputLength, const int spaceProximityPos, CorrectionState *correctionState);
-    void onTerminal(unsigned short int* word, const int depth,
-            const int inputIndex, const int freq,
-            CorrectionState *correctionState);
+    void onTerminal(unsigned short int* word, const int freq, CorrectionState *correctionState);
     bool needsToSkipCurrentNode(const unsigned short c,
             const int inputIndex, const int skipPos, const int depth);
     // Process a node by considering proximity, missing and excessive character
-    bool processCurrentNode(const int initialPos, const int initialDepth,
-            const int maxDepth, const bool initialTraverseAllNodes, int inputIndex,
-            const int initialDiffs,
+    bool processCurrentNode(const int initialPos, const int maxDepth,
+            const bool initialTraverseAllNodes, const int initialDiffs,
             CorrectionState *correctionState, int *newCount, int *newChildPosition,
-            bool *newTraverseAllNodes, int *newInputIndex, int *newDiffs,
-            int *nextSiblingPosition, int *nextOutputIndex);
+            bool *newTraverseAllNodes, int *newDiffs, int *nextSiblingPosition);
     int getMostFrequentWordLike(const int startInputIndex, const int inputLength,
             unsigned short *word);
     int getMostFrequentWordLikeInner(const uint16_t* const inWord, const int length,