Combilne normal correction and skip correction

Change-Id: Ide868d977c0f35900340c7be1b71d572c69a8806
diff --git a/native/src/correction.cpp b/native/src/correction.cpp
index f8f73dd..a4090a9 100644
--- a/native/src/correction.cpp
+++ b/native/src/correction.cpp
@@ -21,6 +21,7 @@
 #define LOG_TAG "LatinIME: correction.cpp"
 
 #include "correction.h"
+#include "dictionary.h"
 #include "proximity_info.h"
 
 namespace latinime {
@@ -93,16 +94,11 @@
         return -1;
     }
 
-    // TODO: Remove this
-    if (mSkipPos >= 0 && mSkippedCount <= 0) {
-        return -1;
-    }
-
     *word = mWord;
     const bool sameLength = (mExcessivePos == mInputLength - 1) ? (mInputLength == inputIndex + 2)
             : (mInputLength == inputIndex + 1);
     return Correction::RankingAlgorithm::calculateFinalFreq(
-            inputIndex, outputIndex, freq, sameLength, this);
+            inputIndex, outputIndex, freq, sameLength, mEditDistanceTable, this);
 }
 
 bool Correction::initProcessState(const int outputIndex) {
@@ -117,6 +113,7 @@
     mSkippedCount = mCorrectionStates[outputIndex].mSkippedCount;
     mSkipPos = mCorrectionStates[outputIndex].mSkipPos;
     mSkipping = false;
+    mProximityMatching = false;
     mMatching = false;
     return true;
 }
@@ -160,6 +157,7 @@
     mCorrectionStates[mOutputIndex].mSkipping = mSkipping;
     mCorrectionStates[mOutputIndex].mSkipPos = mSkipPos;
     mCorrectionStates[mOutputIndex].mMatching = mMatching;
+    mCorrectionStates[mOutputIndex].mProximityMatching = mProximityMatching;
 }
 
 void Correction::startToTraverseAllNodes() {
@@ -207,6 +205,20 @@
     }
 
     if (mNeedsToTraverseAllNodes || isQuote(c)) {
+        const bool checkProximityChars =
+                !(mSkippedCount > 0 || mExcessivePos >= 0 || mTransposedPos >= 0);
+        // Note: This logic tries saving cases like contrst --> contrast -- "a" is one of
+        // proximity chars of "s", but it should rather be handled as a skipped char.
+        if (checkProximityChars
+                && mInputIndex > 0
+                && mCorrectionStates[mOutputIndex].mProximityMatching
+                && mCorrectionStates[mOutputIndex].mSkipping
+                && mProximityInfo->getMatchedProximityId(
+                        mInputIndex - 1, c, false)
+                        == ProximityInfo::SAME_OR_ACCENTED_OR_CAPITALIZED_CHAR) {
+            ++mSkippedCount;
+            --mProximityCount;
+        }
         return processSkipChar(c, isTerminal);
     } else {
         int inputIndexForProximity = mInputIndex;
@@ -220,16 +232,27 @@
             }
         }
 
+        // TODO: sum counters
         const bool checkProximityChars =
-                !(mSkipPos >= 0 || mExcessivePos >= 0 || mTransposedPos >= 0);
+                !(mSkippedCount > 0 || mExcessivePos >= 0 || mTransposedPos >= 0);
         int matchedProximityCharId = mProximityInfo->getMatchedProximityId(
                 inputIndexForProximity, c, checkProximityChars);
 
         if (ProximityInfo::UNRELATED_CHAR == matchedProximityCharId) {
-            if (skip) {
+            if (skip && mProximityCount == 0) {
                 // Skip this letter and continue deeper
                 ++mSkippedCount;
                 return processSkipChar(c, isTerminal);
+            } else if (checkProximityChars
+                    && inputIndexForProximity > 0
+                    && mCorrectionStates[mOutputIndex].mProximityMatching
+                    && mCorrectionStates[mOutputIndex].mSkipping
+                    && mProximityInfo->getMatchedProximityId(
+                            inputIndexForProximity - 1, c, false)
+                                    == ProximityInfo::SAME_OR_ACCENTED_OR_CAPITALIZED_CHAR) {
+                ++mSkippedCount;
+                --mProximityCount;
+                return processSkipChar(c, isTerminal);
             } else {
                 return UNRELATED;
             }
@@ -238,6 +261,7 @@
             // proximity chars. So, we don't need to check proximity.
             mMatching = true;
         } else if (ProximityInfo::NEAR_PROXIMITY_CHAR == matchedProximityCharId) {
+            mProximityMatching = true;
             incrementProximityCount();
         }
 
@@ -320,29 +344,116 @@
     }
 }
 
+/* static */
+inline static int editDistance(
+        int* editDistanceTable, const unsigned short* input,
+        const int inputLength, const unsigned short* output, const int outputLength) {
+    // dp[li][lo] dp[a][b] = dp[ a * lo + b]
+    int* dp = editDistanceTable;
+    const int li = inputLength + 1;
+    const int lo = outputLength + 1;
+    for (int i = 0; i < li; ++i) {
+        dp[lo * i] = i;
+    }
+    for (int i = 0; i < lo; ++i) {
+        dp[i] = i;
+    }
+
+    for (int i = 0; i < li - 1; ++i) {
+        for (int j = 0; j < lo - 1; ++j) {
+            const uint32_t ci = Dictionary::toBaseLowerCase(input[i]);
+            const uint32_t co = Dictionary::toBaseLowerCase(output[j]);
+            const uint16_t cost = (ci == co) ? 0 : 1;
+            dp[(i + 1) * lo + (j + 1)] = min(dp[i * lo + (j + 1)] + 1,
+                    min(dp[(i + 1) * lo + j] + 1, dp[i * lo + j] + cost));
+            if (li > 0 && lo > 0
+                    && ci == Dictionary::toBaseLowerCase(output[j - 1])
+                    && co == Dictionary::toBaseLowerCase(input[i - 1])) {
+                dp[(i + 1) * lo + (j + 1)] = min(
+                        dp[(i + 1) * lo + (j + 1)], dp[(i - 1) * lo + (j - 1)] + cost);
+            }
+        }
+    }
+
+    if (DEBUG_EDIT_DISTANCE) {
+        LOGI("IN = %d, OUT = %d", inputLength, outputLength);
+        for (int i = 0; i < li; ++i) {
+            for (int j = 0; j < lo; ++j) {
+                LOGI("EDIT[%d][%d], %d", i, j, dp[i * lo + j]);
+            }
+        }
+    }
+    return dp[li * lo - 1];
+}
+
 //////////////////////
 // RankingAlgorithm //
 //////////////////////
 
 /* static */
 int Correction::RankingAlgorithm::calculateFinalFreq(const int inputIndex, const int outputIndex,
-        const int freq, const bool sameLength, const Correction* correction) {
+        const int freq, const bool sameLength, int* editDistanceTable,
+        const Correction* correction) {
     const int excessivePos = correction->getExcessivePos();
     const int transposedPos = correction->getTransposedPos();
     const int inputLength = correction->mInputLength;
     const int typedLetterMultiplier = correction->TYPED_LETTER_MULTIPLIER;
     const int fullWordMultiplier = correction->FULL_WORD_MULTIPLIER;
     const ProximityInfo *proximityInfo = correction->mProximityInfo;
+    const int skipCount = correction->mSkippedCount;
+    const int proximityMatchedCount = correction->mProximityCount;
 
     // TODO: use mExcessiveCount
-    const int matchCount = inputLength - correction->mProximityCount - (excessivePos >= 0 ? 1 : 0);
-    const int matchWeight = powerIntCapped(typedLetterMultiplier, matchCount);
+    int matchCount = inputLength - correction->mProximityCount - (excessivePos >= 0 ? 1 : 0);
 
     const unsigned short* word = correction->mWord;
-    const bool skipped = correction->mSkippedCount > 0;
+    const bool skipped = skipCount > 0;
+
+    // ----- TODO: use edit distance here as follows? ---------------------- /
+    //if (!skipped && excessivePos < 0 && transposedPos < 0) {
+    //    const int ed = editDistance(dp, proximityInfo->getInputWord(),
+    //            inputLength, word, outputIndex + 1);
+    //    matchCount = outputIndex + 1 - ed;
+    //    if (ed == 1 && !sameLength) ++matchCount;
+    //}
+    //    const int ed = editDistance(dp, proximityInfo->getInputWord(),
+    //    inputLength, word, outputIndex + 1);
+    //    if (ed == 1 && !sameLength) ++matchCount; ------------------------ /
+    int matchWeight = powerIntCapped(typedLetterMultiplier, matchCount);
 
     // TODO: Demote by edit distance
     int finalFreq = freq * matchWeight;
+    // +1 +11/-12
+    /*if (inputLength == outputIndex && !skipped && excessivePos < 0 && transposedPos < 0) {
+        const int ed = editDistance(dp, proximityInfo->getInputWord(),
+                inputLength, word, outputIndex + 1);
+        if (ed == 1) {
+            multiplyRate(160, &finalFreq);
+        }
+    }*/
+    if (inputLength == outputIndex && excessivePos < 0 && transposedPos < 0
+            && (proximityMatchedCount > 0 || skipped)) {
+        const int ed = editDistance(editDistanceTable, proximityInfo->getPrimaryInputWord(),
+                inputLength, word, outputIndex + 1);
+        if (ed == 1) {
+            multiplyRate(160, &finalFreq);
+        }
+    }
+
+    // TODO: Promote properly?
+    //if (skipCount == 1 && excessivePos < 0 && transposedPos < 0 && inputLength == outputIndex
+    //        && !sameLength) {
+    //    multiplyRate(150, &finalFreq);
+    //}
+    //if (skipCount == 0 && excessivePos < 0 && transposedPos < 0 && inputLength == outputIndex
+    //        && !sameLength) {
+    //    multiplyRate(150, &finalFreq);
+    //}
+    //if (skipCount == 0 && excessivePos < 0 && transposedPos < 0
+    //        && inputLength == outputIndex + 1) {
+    //    multiplyRate(150, &finalFreq);
+    //}
+
     if (skipped) {
         if (inputLength >= 2) {
             const int demotionRate = WORDS_WITH_MISSING_CHARACTER_DEMOTION_RATE
@@ -389,7 +500,7 @@
         multiplyIntCapped(typedLetterMultiplier, &finalFreq);
         multiplyRate(WORDS_WITH_PROXIMITY_CHARACTER_DEMOTION_RATE, &finalFreq);
     }
-    if (DEBUG_DICT) {
+    if (DEBUG_DICT_FULL) {
         LOGI("calc: %d, %d", outputIndex, sameLength);
     }
     if (sameLength) multiplyIntCapped(fullWordMultiplier, &finalFreq);
diff --git a/native/src/correction.h b/native/src/correction.h
index 2fa8c90..9d385a4 100644
--- a/native/src/correction.h
+++ b/native/src/correction.h
@@ -120,6 +120,8 @@
     int mTerminalInputIndex;
     int mTerminalOutputIndex;
     unsigned short mWord[MAX_WORD_LENGTH_INTERNAL];
+    // Caveat: Do not create multiple tables per thread as this table eats up RAM a lot.
+    int mEditDistanceTable[MAX_WORD_LENGTH_INTERNAL * MAX_WORD_LENGTH_INTERNAL];
 
     CorrectionState mCorrectionStates[MAX_WORD_LENGTH_INTERNAL];
 
@@ -132,11 +134,13 @@
     bool mNeedsToTraverseAllNodes;
     bool mMatching;
     bool mSkipping;
+    bool mProximityMatching;
 
     class RankingAlgorithm {
     public:
         static int calculateFinalFreq(const int inputIndex, const int depth,
-                const int freq, const bool sameLength, const Correction* correction);
+                const int freq, const bool sameLength, int *editDistanceTable,
+                const Correction* correction);
         static int calcFreqForSplitTwoWords(const int firstFreq, const int secondFreq,
                 const Correction* correction);
     };
diff --git a/native/src/correction_state.h b/native/src/correction_state.h
index d30d13c..267deda 100644
--- a/native/src/correction_state.h
+++ b/native/src/correction_state.h
@@ -33,6 +33,7 @@
     int8_t mSkipPos; // should be signed
     bool mMatching;
     bool mSkipping;
+    bool mProximityMatching;
     bool mNeedsToTraverseAllNodes;
 
 };
@@ -47,6 +48,7 @@
     state->mSkippedCount = 0;
     state->mMatching = false;
     state->mSkipping = false;
+    state->mProximityMatching = false;
     state->mNeedsToTraverseAllNodes = traverseAll;
     state->mSkipPos = -1;
 }
diff --git a/native/src/defines.h b/native/src/defines.h
index c1838d3..c1d08e6 100644
--- a/native/src/defines.h
+++ b/native/src/defines.h
@@ -94,20 +94,36 @@
 #endif
 #define DEBUG_DICT true
 #define DEBUG_DICT_FULL false
+#define DEBUG_EDIT_DISTANCE false
 #define DEBUG_SHOW_FOUND_WORD DEBUG_DICT_FULL
 #define DEBUG_NODE DEBUG_DICT_FULL
 #define DEBUG_TRACE DEBUG_DICT_FULL
 #define DEBUG_PROXIMITY_INFO true
 
+#define DUMP_WORD(word, length) do { dumpWord(word, length); } while(0)
+
+static char charBuf[50];
+
+static void dumpWord(const unsigned short* word, const int length) {
+    for (int i = 0; i < length; ++i) {
+        charBuf[i] = word[i];
+    }
+    charBuf[length] = 0;
+    LOGI("[ %s ]", charBuf);
+}
+
 #else // FLAG_DBG
 
 #define DEBUG_DICT false
 #define DEBUG_DICT_FULL false
+#define DEBUG_EDIT_DISTANCE false
 #define DEBUG_SHOW_FOUND_WORD false
 #define DEBUG_NODE false
 #define DEBUG_TRACE false
 #define DEBUG_PROXIMITY_INFO false
 
+#define DUMP_WORD(word, length)
+
 #endif // FLAG_DBG
 
 #ifndef U_SHORT_MAX
diff --git a/native/src/proximity_info.cpp b/native/src/proximity_info.cpp
index d437e25..361bdac 100644
--- a/native/src/proximity_info.cpp
+++ b/native/src/proximity_info.cpp
@@ -68,6 +68,10 @@
 void ProximityInfo::setInputParams(const int* inputCodes, const int inputLength) {
     mInputCodes = inputCodes;
     mInputLength = inputLength;
+    for (int i = 0; i < inputLength; ++i) {
+        mPrimaryInputWord[i] = getPrimaryCharAt(i);
+    }
+    mPrimaryInputWord[inputLength] = 0;
 }
 
 inline const int* ProximityInfo::getProximityCharsAt(const int index) const {
diff --git a/native/src/proximity_info.h b/native/src/proximity_info.h
index d9ed46f..75fc8fb 100644
--- a/native/src/proximity_info.h
+++ b/native/src/proximity_info.h
@@ -46,6 +46,9 @@
     ProximityType getMatchedProximityId(
             const int index, const unsigned short c, const bool checkProximityChars) const;
     bool sameAsTyped(const unsigned short *word, int length) const;
+    const unsigned short* getPrimaryInputWord() const {
+        return mPrimaryInputWord;
+    }
 
 private:
     int getStartIndexFromCoordinates(const int x, const int y) const;
@@ -59,6 +62,7 @@
     const int *mInputCodes;
     uint32_t *mProximityCharsArray;
     int mInputLength;
+    unsigned short mPrimaryInputWord[MAX_WORD_LENGTH_INTERNAL];
 };
 
 } // namespace latinime
diff --git a/native/src/unigram_dictionary.cpp b/native/src/unigram_dictionary.cpp
index 6517bc0..6bc3505 100644
--- a/native/src/unigram_dictionary.cpp
+++ b/native/src/unigram_dictionary.cpp
@@ -187,8 +187,9 @@
     mCorrection->initCorrection(mProximityInfo, mInputLength, maxDepth);
     PROF_END(0);
 
+    // TODO: remove
     PROF_START(1);
-    getSuggestionCandidates(-1, -1, -1);
+    // Note: This line is intentionally left blank
     PROF_END(1);
 
     PROF_START(2);