Search bigrams for the lower case version of the word (A46)

...if there aren't any for the exact case version.

Bug: 6752830
Change-Id: I2737148b01ba04a64febe009ceb2ef53c265d224
diff --git a/java/src/com/android/inputmethod/latin/Suggest.java b/java/src/com/android/inputmethod/latin/Suggest.java
index dcfda86..f810ecc 100644
--- a/java/src/com/android/inputmethod/latin/Suggest.java
+++ b/java/src/com/android/inputmethod/latin/Suggest.java
@@ -177,19 +177,9 @@
         if (wordComposer.size() <= 1 && isCorrectionEnabled) {
             // At first character typed, search only the bigrams
             if (!TextUtils.isEmpty(prevWordForBigram)) {
-                final CharSequence lowerPrevWord;
-                if (StringUtils.hasUpperCase(prevWordForBigram)) {
-                    // TODO: Must pay attention to locale when changing case.
-                    lowerPrevWord = prevWordForBigram.toString().toLowerCase();
-                } else {
-                    lowerPrevWord = null;
-                }
                 for (final String key : mDictionaries.keySet()) {
                     final Dictionary dictionary = mDictionaries.get(key);
                     suggestionsSet.addAll(dictionary.getBigrams(wordComposer, prevWordForBigram));
-                    if (null != lowerPrevWord) {
-                        suggestionsSet.addAll(dictionary.getBigrams(wordComposer, lowerPrevWord));
-                    }
                 }
             }
         } else if (wordComposer.size() > 1) {
diff --git a/java/src/com/android/inputmethod/latin/UserHistoryDictionaryBigramList.java b/java/src/com/android/inputmethod/latin/UserHistoryDictionaryBigramList.java
index 2884774..610652a 100644
--- a/java/src/com/android/inputmethod/latin/UserHistoryDictionaryBigramList.java
+++ b/java/src/com/android/inputmethod/latin/UserHistoryDictionaryBigramList.java
@@ -98,11 +98,11 @@
     }
 
     public HashMap<String, Byte> getBigrams(String word1) {
-        if (!mBigramMap.containsKey(word1)) {
-            return EMPTY_BIGRAM_MAP;
-        } else {
-            return mBigramMap.get(word1);
-        }
+        if (mBigramMap.containsKey(word1)) return mBigramMap.get(word1);
+        // TODO: lower case according to locale
+        final String lowerWord1 = word1.toLowerCase();
+        if (mBigramMap.containsKey(lowerWord1)) return mBigramMap.get(lowerWord1);
+        return EMPTY_BIGRAM_MAP;
     }
 
     public boolean removeBigram(String word1, String word2) {
diff --git a/native/jni/src/bigram_dictionary.cpp b/native/jni/src/bigram_dictionary.cpp
index 1443369..3bfbfad 100644
--- a/native/jni/src/bigram_dictionary.cpp
+++ b/native/jni/src/bigram_dictionary.cpp
@@ -105,8 +105,15 @@
     // TODO: have "in" arguments before "out" ones, and make out args explicit in the name
 
     const uint8_t* const root = DICT;
-    int pos = getBigramListPositionForWord(prevWord, prevWordLength);
+    int pos = getBigramListPositionForWord(prevWord, prevWordLength,
+            false /* forceLowerCaseSearch */);
     // getBigramListPositionForWord returns 0 if this word isn't in the dictionary or has no bigrams
+    if (0 == pos) {
+        // If no bigrams for this exact word, search again in lower case.
+        pos = getBigramListPositionForWord(prevWord, prevWordLength,
+                true /* forceLowerCaseSearch */);
+    }
+    // If still no bigrams, we really don't have them!
     if (0 == pos) return 0;
     int bigramFlags;
     int bigramCount = 0;
@@ -141,10 +148,11 @@
 // Returns a pointer to the start of the bigram list.
 // If the word is not found or has no bigrams, this function returns 0.
 int BigramDictionary::getBigramListPositionForWord(const int32_t *prevWord,
-        const int prevWordLength) const {
+        const int prevWordLength, const bool forceLowerCaseSearch) const {
     if (0 >= prevWordLength) return 0;
     const uint8_t* const root = DICT;
-    int pos = BinaryFormat::getTerminalPosition(root, prevWord, prevWordLength);
+    int pos = BinaryFormat::getTerminalPosition(root, prevWord, prevWordLength,
+            forceLowerCaseSearch);
 
     if (NOT_VALID_WORD == pos) return 0;
     const int flags = BinaryFormat::getFlagsAndForwardPointer(root, &pos);
@@ -164,7 +172,13 @@
         const int prevWordLength, std::map<int, int> *map, uint8_t *filter) const {
     memset(filter, 0, BIGRAM_FILTER_BYTE_SIZE);
     const uint8_t* const root = DICT;
-    int pos = getBigramListPositionForWord(prevWord, prevWordLength);
+    int pos = getBigramListPositionForWord(prevWord, prevWordLength,
+            false /* forceLowerCaseSearch */);
+    if (0 == pos) {
+        // If no bigrams for this exact string, search again in lower case.
+        pos = getBigramListPositionForWord(prevWord, prevWordLength,
+                true /* forceLowerCaseSearch */);
+    }
     if (0 == pos) return;
 
     int bigramFlags;
@@ -197,10 +211,11 @@
 bool BigramDictionary::isValidBigram(const int32_t *word1, int length1, const int32_t *word2,
         int length2) const {
     const uint8_t* const root = DICT;
-    int pos = getBigramListPositionForWord(word1, length1);
+    int pos = getBigramListPositionForWord(word1, length1, false /* forceLowerCaseSearch */);
     // getBigramListPositionForWord returns 0 if this word isn't in the dictionary or has no bigrams
     if (0 == pos) return false;
-    int nextWordPos = BinaryFormat::getTerminalPosition(root, word2, length2);
+    int nextWordPos = BinaryFormat::getTerminalPosition(root, word2, length2,
+            false /* forceLowerCaseSearch */);
     if (NOT_VALID_WORD == nextWordPos) return false;
     int bigramFlags;
     do {
diff --git a/native/jni/src/bigram_dictionary.h b/native/jni/src/bigram_dictionary.h
index 1ff1b2e..5372276 100644
--- a/native/jni/src/bigram_dictionary.h
+++ b/native/jni/src/bigram_dictionary.h
@@ -30,7 +30,8 @@
     BigramDictionary(const unsigned char *dict, int maxWordLength);
     int getBigrams(const int32_t *word, int length, int *inputCodes, int codesSize,
             unsigned short *outWords, int *frequencies, int maxWordLength, int maxBigrams) const;
-    int getBigramListPositionForWord(const int32_t *prevWord, const int prevWordLength) const;
+    int getBigramListPositionForWord(const int32_t *prevWord, const int prevWordLength,
+            const bool forceLowerCaseSearch) const;
     void fillBigramAddressToFrequencyMapAndFilter(const int32_t *prevWord, const int prevWordLength,
             std::map<int, int> *map, uint8_t *filter) const;
     bool isValidBigram(const int32_t *word1, int length1, const int32_t *word2, int length2) const;
diff --git a/native/jni/src/binary_format.h b/native/jni/src/binary_format.h
index 214ecfa..474c854 100644
--- a/native/jni/src/binary_format.h
+++ b/native/jni/src/binary_format.h
@@ -19,6 +19,7 @@
 
 #include <limits>
 #include "bloom_filter.h"
+#include "char_utils.h"
 #include "unigram_dictionary.h"
 
 namespace latinime {
@@ -65,7 +66,7 @@
     static int getAttributeAddressAndForwardPointer(const uint8_t* const dict, const uint8_t flags,
             int *pos);
     static int getTerminalPosition(const uint8_t* const root, const int32_t* const inWord,
-            const int length);
+            const int length, const bool forceLowerCaseSearch);
     static int getWordAtAddress(const uint8_t* const root, const int address, const int maxDepth,
             uint16_t* outWord, int* outUnigramFrequency);
     static int computeFrequencyForBigram(const int unigramFreq, const int bigramFreq);
@@ -309,7 +310,7 @@
 // This function gets the byte position of the last chargroup of the exact matching word in the
 // dictionary. If no match is found, it returns NOT_VALID_WORD.
 inline int BinaryFormat::getTerminalPosition(const uint8_t* const root,
-        const int32_t* const inWord, const int length) {
+        const int32_t* const inWord, const int length, const bool forceLowerCaseSearch) {
     int pos = 0;
     int wordPos = 0;
 
@@ -318,7 +319,7 @@
         // there was no match (or we would have found it).
         if (wordPos > length) return NOT_VALID_WORD;
         int charGroupCount = BinaryFormat::getGroupCountAndForwardPointer(root, &pos);
-        const int32_t wChar = inWord[wordPos];
+        const int32_t wChar = forceLowerCaseSearch ? toLowerCase(inWord[wordPos]) : inWord[wordPos];
         while (true) {
             // If there are no more character groups in this node, it means we could not
             // find a matching character for this depth, therefore there is no match.
diff --git a/native/jni/src/char_utils.h b/native/jni/src/char_utils.h
index 607dc51..21dca9a 100644
--- a/native/jni/src/char_utils.h
+++ b/native/jni/src/char_utils.h
@@ -50,8 +50,7 @@
     return c;
 }
 
-inline static unsigned short toBaseLowerCase(unsigned short c) {
-    c = toBaseChar(c);
+inline static unsigned short toLowerCase(const unsigned short c) {
     if (isAsciiUpper(c)) {
         return toAsciiLower(c);
     } else if (isAscii(c)) {
@@ -60,6 +59,10 @@
     return latin_tolower(c);
 }
 
+inline static unsigned short toBaseLowerCase(const unsigned short c) {
+    return toLowerCase(toBaseChar(c));
+}
+
 } // namespace latinime
 
 #endif // LATINIME_CHAR_UTILS_H
diff --git a/native/jni/src/unigram_dictionary.cpp b/native/jni/src/unigram_dictionary.cpp
index 3417d2b..22f1657 100644
--- a/native/jni/src/unigram_dictionary.cpp
+++ b/native/jni/src/unigram_dictionary.cpp
@@ -817,7 +817,8 @@
 
 int UnigramDictionary::getFrequency(const int32_t* const inWord, const int length) const {
     const uint8_t* const root = DICT_ROOT;
-    int pos = BinaryFormat::getTerminalPosition(root, inWord, length);
+    int pos = BinaryFormat::getTerminalPosition(root, inWord, length,
+            false /* forceLowerCaseSearch */);
     if (NOT_VALID_WORD == pos) {
         return NOT_A_PROBABILITY;
     }