Add correction state
Change-Id: I0d281cede1590893bd1def005cf83c9431d12750
diff --git a/native/src/unigram_dictionary.cpp b/native/src/unigram_dictionary.cpp
index bccd37a..27d1515 100644
--- a/native/src/unigram_dictionary.cpp
+++ b/native/src/unigram_dictionary.cpp
@@ -58,9 +58,12 @@
if (DEBUG_DICT) {
LOGI("UnigramDictionary - constructor");
}
+ mCorrectionState = new CorrectionState();
}
-UnigramDictionary::~UnigramDictionary() {}
+UnigramDictionary::~UnigramDictionary() {
+ delete mCorrectionState; // Frees the CorrectionState allocated with new in the constructor.
+}
static inline unsigned int getCodesBufferSize(const int* codes, const int codesSize,
const int MAX_PROXIMITY_CHARS) {
@@ -362,6 +365,8 @@
assert(excessivePos < mInputLength);
assert(missingPos < mInputLength);
}
+ mCorrectionState->setCorrectionParams(mProximityInfo, mInputLength, skipPos, excessivePos,
+ transposedPos);
int rootPosition = ROOT_POS;
// Get the number of children of root, then increment the position
int childCount = Dictionary::getCount(DICT_ROOT, &rootPosition);
@@ -389,8 +394,8 @@
// depth will never be greater than maxDepth because in that case,
// needsToTraverseChildrenNodes should be false
const bool needsToTraverseChildrenNodes = processCurrentNode(siblingPos, outputIndex,
- maxDepth, traverseAllNodes, matchWeight, inputIndex, diffs, skipPos,
- excessivePos, transposedPos, nextLetters, nextLettersSize, &childCount,
+ maxDepth, traverseAllNodes, matchWeight, inputIndex, diffs,
+ nextLetters, nextLettersSize, mCorrectionState, &childCount,
&firstChildPos, &traverseAllNodes, &matchWeight, &inputIndex, &diffs,
&siblingPos, &outputIndex);
// Update next sibling pos
@@ -521,8 +526,12 @@
}
inline int UnigramDictionary::calculateFinalFreq(const int inputIndex, const int depth,
- const int matchWeight, const int skipPos, const int excessivePos, const int transposedPos,
- const int freq, const bool sameLength) const {
+ const int matchWeight, const int freq, const bool sameLength,
+ CorrectionState *correctionState) const {
+ const int skipPos = correctionState->getSkipPos();
+ const int excessivePos = correctionState->getExcessivePos();
+ const int transposedPos = correctionState->getTransposedPos();
+
// TODO: Demote by edit distance
int finalFreq = freq * matchWeight;
if (skipPos >= 0) {
@@ -587,16 +596,16 @@
inline void UnigramDictionary::onTerminal(unsigned short int* word, const int depth,
const uint8_t* const root, const uint8_t flags, const int pos,
- const int inputIndex, const int matchWeight, const int skipPos,
- const int excessivePos, const int transposedPos, const int freq, const bool sameLength,
- int* nextLetters, const int nextLettersSize) {
+ const int inputIndex, const int matchWeight, const int freq, const bool sameLength,
+ int* nextLetters, const int nextLettersSize, CorrectionState *correctionState) {
+ const int skipPos = correctionState->getSkipPos();
const bool isSameAsTyped = sameLength ? mProximityInfo->sameAsTyped(word, depth + 1) : false;
if (isSameAsTyped) return;
if (depth >= MIN_SUGGEST_DEPTH) {
- const int finalFreq = calculateFinalFreq(inputIndex, depth, matchWeight, skipPos,
- excessivePos, transposedPos, freq, sameLength);
+ const int finalFreq = calculateFinalFreq(inputIndex, depth, matchWeight,
+ freq, sameLength, correctionState);
if (!isSameAsTyped)
addWord(word, depth + 1, finalFreq);
}
@@ -648,48 +657,6 @@
}
#ifndef NEW_DICTIONARY_FORMAT
-// The following functions will be entirely replaced with new implementations.
-void UnigramDictionary::getWordsOld(const int initialPos, const int inputLength, const int skipPos,
- const int excessivePos, const int transposedPos,int *nextLetters,
- const int nextLettersSize) {
- int initialPosition = initialPos;
- const int count = Dictionary::getCount(DICT_ROOT, &initialPosition);
- getWordsRec(count, initialPosition, 0,
- min(inputLength * MAX_DEPTH_MULTIPLIER, MAX_WORD_LENGTH),
- mInputLength <= 0, 1, 0, 0, skipPos, excessivePos, transposedPos, nextLetters,
- nextLettersSize);
-}
-
-void UnigramDictionary::getWordsRec(const int childrenCount, const int pos, const int depth,
- const int maxDepth, const bool traverseAllNodes, const int matchWeight,
- const int inputIndex, const int diffs, const int skipPos, const int excessivePos,
- const int transposedPos, int *nextLetters, const int nextLettersSize) {
- int siblingPos = pos;
- for (int i = 0; i < childrenCount; ++i) {
- int newCount;
- int newChildPosition;
- bool newTraverseAllNodes;
- int newMatchRate;
- int newInputIndex;
- int newDiffs;
- int newSiblingPos;
- int newOutputIndex;
- const bool needsToTraverseChildrenNodes = processCurrentNode(siblingPos, depth, maxDepth,
- traverseAllNodes, matchWeight, inputIndex, diffs,
- skipPos, excessivePos, transposedPos,
- nextLetters, nextLettersSize,
- &newCount, &newChildPosition, &newTraverseAllNodes, &newMatchRate,
- &newInputIndex, &newDiffs, &newSiblingPos, &newOutputIndex);
- siblingPos = newSiblingPos;
-
- if (needsToTraverseChildrenNodes) {
- getWordsRec(newCount, newChildPosition, newOutputIndex, maxDepth, newTraverseAllNodes,
- newMatchRate, newInputIndex, newDiffs, skipPos, excessivePos, transposedPos,
- nextLetters, nextLettersSize);
- }
- }
-}
-
inline int UnigramDictionary::getMostFrequentWordLike(const int startInputIndex,
const int inputLength, unsigned short *word) {
int pos = ROOT_POS;
@@ -829,10 +796,13 @@
// The following functions will be modified.
inline bool UnigramDictionary::processCurrentNode(const int initialPos, const int initialDepth,
const int maxDepth, const bool initialTraverseAllNodes, int matchWeight, int inputIndex,
- const int initialDiffs, const int skipPos, const int excessivePos, const int transposedPos,
- int *nextLetters, const int nextLettersSize, int *newCount, int *newChildPosition,
+ const int initialDiffs, int *nextLetters, const int nextLettersSize,
+ CorrectionState *correctionState, int *newCount, int *newChildPosition,
bool *newTraverseAllNodes, int *newMatchRate, int *newInputIndex, int *newDiffs,
int *nextSiblingPosition, int *nextOutputIndex) {
+ const int skipPos = correctionState->getSkipPos();
+ const int excessivePos = correctionState->getExcessivePos();
+ const int transposedPos = correctionState->getTransposedPos();
if (DEBUG_DICT) {
int inputCount = 0;
if (skipPos >= 0) ++inputCount;
@@ -865,8 +835,8 @@
if (traverseAllNodes || needsToSkipCurrentNode(c, inputIndex, skipPos, depth)) {
mWord[depth] = c;
if (traverseAllNodes && terminal) {
- onTerminal(mWord, depth, DICT_ROOT, flags, pos, inputIndex, matchWeight, skipPos,
- excessivePos, transposedPos, freq, false, nextLetters, nextLettersSize);
+ onTerminal(mWord, depth, DICT_ROOT, flags, pos, inputIndex, matchWeight,
+ freq, false, nextLetters, nextLettersSize, mCorrectionState);
}
if (!needsToTraverseChildrenNodes) return false;
*newTraverseAllNodes = traverseAllNodes;
@@ -882,7 +852,7 @@
}
ProximityInfo::ProximityType matchedProximityCharId = mProximityInfo->getMatchedProximityId(
- inputIndexForProximity, c, skipPos, excessivePos, transposedPos);
+ inputIndexForProximity, c, mCorrectionState);
if (ProximityInfo::UNRELATED_CHAR == matchedProximityCharId) return false;
mWord[depth] = c;
// If inputIndex is greater than mInputLength, that means there is no
@@ -893,8 +863,8 @@
bool isSameAsUserTypedLength = mInputLength == inputIndex + 1
|| (excessivePos == mInputLength - 1 && inputIndex == mInputLength - 2);
if (isSameAsUserTypedLength && terminal) {
- onTerminal(mWord, depth, DICT_ROOT, flags, pos, inputIndex, matchWeight, skipPos,
- excessivePos, transposedPos, freq, true, nextLetters, nextLettersSize);
+ onTerminal(mWord, depth, DICT_ROOT, flags, pos, inputIndex, matchWeight,
+ freq, true, nextLetters, nextLettersSize, mCorrectionState);
}
if (!needsToTraverseChildrenNodes) return false;
// Start traversing all nodes after the index exceeds the user typed length
@@ -1081,16 +1051,15 @@
// given level, as output into newCount when traversing this level's parent.
inline bool UnigramDictionary::processCurrentNode(const int initialPos, const int initialDepth,
const int maxDepth, const bool initialTraverseAllNodes, int matchWeight, int inputIndex,
- const int initialDiffs, const int skipPos, const int excessivePos, const int transposedPos,
- int *nextLetters, const int nextLettersSize, int *newCount, int *newChildrenPosition,
+ const int initialDiffs, int *nextLetters, const int nextLettersSize,
+ CorrectionState *correctionState, int *newCount, int *newChildrenPosition,
bool *newTraverseAllNodes, int *newMatchRate, int *newInputIndex, int *newDiffs,
int *nextSiblingPosition, int *newOutputIndex) {
+ const int skipPos = correctionState->getSkipPos();
+ const int excessivePos = correctionState->getExcessivePos();
+ const int transposedPos = correctionState->getTransposedPos();
if (DEBUG_DICT) {
- int inputCount = 0;
- if (skipPos >= 0) ++inputCount;
- if (excessivePos >= 0) ++inputCount;
- if (transposedPos >= 0) ++inputCount;
- assert(inputCount <= 1);
+ correctionState->checkState();
}
int pos = initialPos;
int depth = initialDepth;
@@ -1146,8 +1115,8 @@
// The frequency should be here, because we come here only if this is actually
// a terminal node, and we are on its last char.
const int freq = BinaryFormat::readFrequencyWithoutMovingPointer(DICT_ROOT, pos);
- onTerminal(mWord, depth, DICT_ROOT, flags, pos, inputIndex, matchWeight, skipPos,
- excessivePos, transposedPos, freq, false, nextLetters, nextLettersSize);
+ onTerminal(mWord, depth, DICT_ROOT, flags, pos, inputIndex, matchWeight,
+ freq, false, nextLetters, nextLettersSize, mCorrectionState);
}
if (!hasChildren) {
// If we don't have children here, that means we finished processing all
@@ -1170,7 +1139,7 @@
}
int matchedProximityCharId = mProximityInfo->getMatchedProximityId(
- inputIndexForProximity, c, skipPos, excessivePos, transposedPos);
+ inputIndexForProximity, c, mCorrectionState);
if (ProximityInfo::UNRELATED_CHAR == matchedProximityCharId) {
// We found that this is an unrelated character, so we should give up traversing
// this node and its children entirely.
@@ -1197,8 +1166,8 @@
|| (excessivePos == mInputLength - 1 && inputIndex == mInputLength - 2);
if (isSameAsUserTypedLength && isTerminal) {
const int freq = BinaryFormat::readFrequencyWithoutMovingPointer(DICT_ROOT, pos);
- onTerminal(mWord, depth, DICT_ROOT, flags, pos, inputIndex, matchWeight, skipPos,
- excessivePos, transposedPos, freq, true, nextLetters, nextLettersSize);
+ onTerminal(mWord, depth, DICT_ROOT, flags, pos, inputIndex, matchWeight,
+ freq, true, nextLetters, nextLettersSize, mCorrectionState);
}
// This character matched the typed character (enough to traverse the node at least)
// so we just evaluated it. Now we should evaluate this virtual node's children - that