Implement getMaxProbabilityOfExactMatches().

Bug: 13142176
Bug: 15428247
Change-Id: I5be6d683be95505412615ca7c88260de1ea05f54
diff --git a/java/src/com/android/inputmethod/latin/BinaryDictionary.java b/java/src/com/android/inputmethod/latin/BinaryDictionary.java
index 284dadd..7247a1f 100644
--- a/java/src/com/android/inputmethod/latin/BinaryDictionary.java
+++ b/java/src/com/android/inputmethod/latin/BinaryDictionary.java
@@ -356,6 +356,7 @@
         return getProbabilityNative(mNativeDict, codePoints);
     }
 
+    @Override
     public int getMaxFrequencyOfExactMatches(final String word) {
         if (TextUtils.isEmpty(word)) return NOT_A_PROBABILITY;
         int[] codePoints = StringUtils.toCodePointArray(word);
diff --git a/native/jni/NativeFileList.mk b/native/jni/NativeFileList.mk
index cb337e6..07a82a9 100644
--- a/native/jni/NativeFileList.mk
+++ b/native/jni/NativeFileList.mk
@@ -28,6 +28,7 @@
     $(addprefix suggest/core/dictionary/, \
         bigram_dictionary.cpp \
         dictionary.cpp \
+        dictionary_utils.cpp \
         digraph_utils.cpp \
         error_type_utils.cpp \
         multi_bigram_map.cpp \
diff --git a/native/jni/com_android_inputmethod_latin_BinaryDictionary.cpp b/native/jni/com_android_inputmethod_latin_BinaryDictionary.cpp
index bbeb8dd..476338e 100644
--- a/native/jni/com_android_inputmethod_latin_BinaryDictionary.cpp
+++ b/native/jni/com_android_inputmethod_latin_BinaryDictionary.cpp
@@ -280,8 +280,7 @@
     const jsize wordLength = env->GetArrayLength(word);
     int codePoints[wordLength];
     env->GetIntArrayRegion(word, 0, wordLength, codePoints);
-    // TODO: Implement.
-    return NOT_A_PROBABILITY;
+    return dictionary->getMaxProbabilityOfExactMatches(codePoints, wordLength);
 }
 
 static jint latinime_BinaryDictionary_getBigramProbability(JNIEnv *env, jclass clazz,
diff --git a/native/jni/src/suggest/core/dicnode/dic_node.h b/native/jni/src/suggest/core/dicnode/dic_node.h
index ef03d2b..92f39ea 100644
--- a/native/jni/src/suggest/core/dicnode/dic_node.h
+++ b/native/jni/src/suggest/core/dicnode/dic_node.h
@@ -125,7 +125,7 @@
         PROF_NODE_COPY(&dicNode->mProfiler, mProfiler);
     }
 
-    void initAsPassingChild(DicNode *parentDicNode) {
+    void initAsPassingChild(const DicNode *parentDicNode) {
         mIsCachedForNextSuggestion = parentDicNode->mIsCachedForNextSuggestion;
         const int codePoint =
                 parentDicNode->mDicNodeState.mDicNodeStateOutput.getCurrentWordCodePointAt(
diff --git a/native/jni/src/suggest/core/dicnode/dic_node_utils.cpp b/native/jni/src/suggest/core/dicnode/dic_node_utils.cpp
index bf2a000..4445f4a 100644
--- a/native/jni/src/suggest/core/dicnode/dic_node_utils.cpp
+++ b/native/jni/src/suggest/core/dicnode/dic_node_utils.cpp
@@ -48,7 +48,7 @@
 ///////////////////////////////////
 // Traverse node expansion utils //
 ///////////////////////////////////
-/* static */ void DicNodeUtils::getAllChildDicNodes(DicNode *dicNode,
+/* static */ void DicNodeUtils::getAllChildDicNodes(const DicNode *dicNode,
         const DictionaryStructureWithBufferPolicy *const dictionaryStructurePolicy,
         DicNodeVector *const childDicNodes) {
     if (dicNode->isTotalInputSizeExceedingLimit()) {
diff --git a/native/jni/src/suggest/core/dicnode/dic_node_utils.h b/native/jni/src/suggest/core/dicnode/dic_node_utils.h
index 0d60e57..00e80c6 100644
--- a/native/jni/src/suggest/core/dicnode/dic_node_utils.h
+++ b/native/jni/src/suggest/core/dicnode/dic_node_utils.h
@@ -35,7 +35,7 @@
             const DictionaryStructureWithBufferPolicy *const dictionaryStructurePolicy,
             const DicNode *const prevWordLastDicNode, DicNode *const newRootDicNode);
     static void initByCopy(const DicNode *const srcDicNode, DicNode *const destDicNode);
-    static void getAllChildDicNodes(DicNode *dicNode,
+    static void getAllChildDicNodes(const DicNode *dicNode,
             const DictionaryStructureWithBufferPolicy *const dictionaryStructurePolicy,
             DicNodeVector *childDicNodes);
     static float getBigramNodeImprobability(
diff --git a/native/jni/src/suggest/core/dicnode/dic_node_vector.h b/native/jni/src/suggest/core/dicnode/dic_node_vector.h
index cb28e57..54cde19 100644
--- a/native/jni/src/suggest/core/dicnode/dic_node_vector.h
+++ b/native/jni/src/suggest/core/dicnode/dic_node_vector.h
@@ -52,7 +52,7 @@
         return static_cast<int>(mDicNodes.size());
     }
 
-    void pushPassingChild(DicNode *dicNode) {
+    void pushPassingChild(const DicNode *dicNode) {
         ASSERT(!mLock);
         mDicNodes.emplace_back();
         mDicNodes.back().initAsPassingChild(dicNode);
diff --git a/native/jni/src/suggest/core/dictionary/dictionary.cpp b/native/jni/src/suggest/core/dictionary/dictionary.cpp
index 898b44f..f88388c 100644
--- a/native/jni/src/suggest/core/dictionary/dictionary.cpp
+++ b/native/jni/src/suggest/core/dictionary/dictionary.cpp
@@ -19,6 +19,7 @@
 #include "suggest/core/dictionary/dictionary.h"
 
 #include "defines.h"
+#include "suggest/core/dictionary/dictionary_utils.h"
 #include "suggest/core/policy/dictionary_header_structure_policy.h"
 #include "suggest/core/result/suggestion_results.h"
 #include "suggest/core/session/dic_traverse_session.h"
@@ -74,6 +75,15 @@
     return getDictionaryStructurePolicy()->getUnigramProbabilityOfPtNode(pos);
 }
 
+// Returns the result of DictionaryUtils::getMaxProbabilityOfExactMatches() for this
+// dictionary's structure policy, after calling TimeKeeper::setCurrentTime().
+int Dictionary::getMaxProbabilityOfExactMatches(const int *word, int length) const {
+    TimeKeeper::setCurrentTime();
+    const DictionaryStructureWithBufferPolicy *const structurePolicy =
+            mDictionaryStructureWithBufferPolicy.get();
+    return DictionaryUtils::getMaxProbabilityOfExactMatches(structurePolicy, word, length);
+}
+
 int Dictionary::getBigramProbability(const PrevWordsInfo *const prevWordsInfo, const int *word,
         int length) const {
     TimeKeeper::setCurrentTime();
diff --git a/native/jni/src/suggest/core/dictionary/dictionary.h b/native/jni/src/suggest/core/dictionary/dictionary.h
index f6d406f..10010b2 100644
--- a/native/jni/src/suggest/core/dictionary/dictionary.h
+++ b/native/jni/src/suggest/core/dictionary/dictionary.h
@@ -73,6 +73,8 @@
 
     int getProbability(const int *word, int length) const;
 
+    int getMaxProbabilityOfExactMatches(const int *word, int length) const;
+
     int getBigramProbability(const PrevWordsInfo *const prevWordsInfo,
             const int *word, int length) const;
 
diff --git a/native/jni/src/suggest/core/dictionary/dictionary_utils.cpp b/native/jni/src/suggest/core/dictionary/dictionary_utils.cpp
new file mode 100644
index 0000000..b94966c
--- /dev/null
+++ b/native/jni/src/suggest/core/dictionary/dictionary_utils.cpp
@@ -0,0 +1,111 @@
+/*
+ * Copyright (C) 2014, The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "suggest/core/dictionary/dictionary_utils.h"
+
+#include <algorithm>
+
+#include "suggest/core/dicnode/dic_node.h"
+#include "suggest/core/dicnode/dic_node_priority_queue.h"
+#include "suggest/core/dicnode/dic_node_vector.h"
+#include "suggest/core/dictionary/dictionary.h"
+#include "suggest/core/dictionary/digraph_utils.h"
+#include "suggest/core/session/prev_words_info.h"
+#include "suggest/core/policy/dictionary_structure_with_buffer_policy.h"
+
+namespace latinime {
+
+// Returns the largest DicNode::getProbability() among terminal nodes whose word matches
+// codePoints[0, codePointCount) when case and accent differences (base-lower comparison),
+// intentional omissions and digraphs are ignored. Returns NOT_A_PROBABILITY if no terminal
+// node matches.
+/* static */ int DictionaryUtils::getMaxProbabilityOfExactMatches(
+        const DictionaryStructureWithBufferPolicy *const dictionaryStructurePolicy,
+        const int *const codePoints, const int codePointCount) {
+    // Nodes matching the input consumed so far, and the nodes gathered for the next step.
+    std::vector<DicNode> current;
+    std::vector<DicNode> next;
+
+    // No prev words information.
+    PrevWordsInfo emptyPrevWordsInfo;
+    int prevWordsPtNodePos[MAX_PREV_WORD_COUNT_FOR_N_GRAM];
+    emptyPrevWordsInfo.getPrevWordsTerminalPtNodePos(dictionaryStructurePolicy,
+            prevWordsPtNodePos, false /* tryLowerCaseSearch */);
+    current.emplace_back();
+    DicNodeUtils::initAsRoot(dictionaryStructurePolicy, prevWordsPtNodePos, &current.front());
+    for (int i = 0; i < codePointCount; ++i) {
+        // The base-lower input is used to ignore case errors and accent errors.
+        const int codePoint = CharUtils::toBaseLowerCase(codePoints[i]);
+        for (const DicNode &dicNode : current) {
+            if (dicNode.isInDigraph() && dicNode.getNodeCodePoint() == codePoint) {
+                // The input matches the pending code point of a digraph; consume it.
+                next.emplace_back(dicNode);
+                next.back().advanceDigraphIndex();
+                continue;
+            }
+            processChildDicNodes(dictionaryStructurePolicy, codePoint, &dicNode, &next);
+        }
+        current.clear();
+        current.swap(next);
+    }
+
+    int maxProbability = NOT_A_PROBABILITY;
+    for (const DicNode &dicNode : current) {
+        if (!dicNode.isTerminalDicNode()) {
+            continue;
+        }
+        // dicNode can contain case errors, accent errors, intentional omissions or digraphs.
+        maxProbability = std::max(maxProbability, dicNode.getProbability());
+    }
+    return maxProbability;
+}
+
+// Appends to outDicNodes the children of parentDicNode that can consume inputCodePoint.
+// A child matches when its base-lower code point equals inputCodePoint. Children for which
+// canBeIntentionalOmission() holds are also descended into recursively, and a child that has
+// a digraph is added in its advanced state when the digraph code point matches the input.
+/* static */ void DictionaryUtils::processChildDicNodes(
+        const DictionaryStructureWithBufferPolicy *const dictionaryStructurePolicy,
+        const int inputCodePoint, const DicNode *const parentDicNode,
+        std::vector<DicNode> *const outDicNodes) {
+    DicNodeVector childDicNodes;
+    DicNodeUtils::getAllChildDicNodes(parentDicNode, dictionaryStructurePolicy, &childDicNodes);
+    for (int childIndex = 0; childIndex < childDicNodes.getSizeAndLock(); ++childIndex) {
+        DicNode *const childDicNode = childDicNodes[childIndex];
+        const int codePoint = CharUtils::toBaseLowerCase(childDicNode->getNodeCodePoint());
+        if (inputCodePoint == codePoint) {
+            outDicNodes->emplace_back(*childDicNode);
+        }
+        if (childDicNode->canBeIntentionalOmission()) {
+            // Also look below the omittable child for nodes matching the same input.
+            processChildDicNodes(dictionaryStructurePolicy, inputCodePoint, childDicNode,
+                    outDicNodes);
+        }
+        if (DigraphUtils::hasDigraphForCodePoint(
+                dictionaryStructurePolicy->getHeaderStructurePolicy(),
+                childDicNode->getNodeCodePoint())) {
+            childDicNode->advanceDigraphIndex();
+            // Keep the digraph candidate only when the typed input matches; comparing with the
+            // child's own base-lower code point would not depend on the input at all.
+            if (childDicNode->getNodeCodePoint() == inputCodePoint) {
+                childDicNode->advanceDigraphIndex();
+                outDicNodes->emplace_back(*childDicNode);
+            }
+        }
+    }
+}
+
+} // namespace latinime
diff --git a/native/jni/src/suggest/core/dictionary/dictionary_utils.h b/native/jni/src/suggest/core/dictionary/dictionary_utils.h
new file mode 100644
index 0000000..358ebf6
--- /dev/null
+++ b/native/jni/src/suggest/core/dictionary/dictionary_utils.h
@@ -0,0 +1,48 @@
+/*
+ * Copyright (C) 2014 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LATINIME_DICTIONARY_UTILS_H
+#define LATINIME_DICTIONARY_UTILS_H
+
+#include <vector>
+
+#include "defines.h"
+
+namespace latinime {
+
+class DicNode;
+class DictionaryStructureWithBufferPolicy;
+
+// Collection of static helper functions for dictionary lookups.
+class DictionaryUtils {
+ public:
+    // Returns the highest probability among the dictionary words regarded as exact matches
+    // of the given code points, or NOT_A_PROBABILITY when there is no such word.
+    static int getMaxProbabilityOfExactMatches(
+            const DictionaryStructureWithBufferPolicy *const dictionaryStructurePolicy,
+            const int *const codePoints, const int codePointCount);
+
+ private:
+    DISALLOW_IMPLICIT_CONSTRUCTORS(DictionaryUtils);
+
+    // Collects into outDicNodes the children of parentDicNode that match inputCodePoint.
+    static void processChildDicNodes(
+            const DictionaryStructureWithBufferPolicy *const dictionaryStructurePolicy,
+            const int inputCodePoint, const DicNode *const parentDicNode,
+            std::vector<DicNode> *const outDicNodes);
+};
+} // namespace latinime
+#endif // LATINIME_DICTIONARY_UTILS_H
diff --git a/tests/src/com/android/inputmethod/latin/BinaryDictionaryTests.java b/tests/src/com/android/inputmethod/latin/BinaryDictionaryTests.java
index ccede0e..55b794c 100644
--- a/tests/src/com/android/inputmethod/latin/BinaryDictionaryTests.java
+++ b/tests/src/com/android/inputmethod/latin/BinaryDictionaryTests.java
@@ -1472,4 +1472,45 @@
         assertEquals(bigramProbability,
                 binaryDictionary.getNgramProbability(prevWordsInfoStartOfSentence, "bbb"));
     }
+
+    /** Runs the exact-match probability test against every dictionary format version. */
+    public void testGetMaxFrequencyOfExactMatches() {
+        for (final int formatVersion : DICT_FORMAT_VERSIONS) {
+            testGetMaxFrequencyOfExactMatches(formatVersion);
+        }
+    }
+
+    /**
+     * Checks that getMaxFrequencyOfExactMatches("abc") reports the highest probability among
+     * words that differ from "abc" only by case, apostrophes or hyphens, and that "ab c",
+     * which contains a space, is not counted as a match.
+     */
+    private void testGetMaxFrequencyOfExactMatches(final int formatVersion) {
+        File dictFile = null;
+        try {
+            dictFile = createEmptyDictionaryAndGetFile("TestBinaryDictionary", formatVersion);
+        } catch (IOException e) {
+            fail("IOException while writing an initial dictionary : " + e);
+        }
+        final BinaryDictionary binaryDictionary = new BinaryDictionary(dictFile.getAbsolutePath(),
+                0 /* offset */, dictFile.length(), true /* useFullEditDistance */,
+                Locale.getDefault(), TEST_LOCALE, true /* isUpdatable */);
+        try {
+            addUnigramWord(binaryDictionary, "abc", 10);
+            addUnigramWord(binaryDictionary, "aBc", 15);
+            assertEquals(15, binaryDictionary.getMaxFrequencyOfExactMatches("abc"));
+            addUnigramWord(binaryDictionary, "ab'c", 20);
+            assertEquals(20, binaryDictionary.getMaxFrequencyOfExactMatches("abc"));
+            addUnigramWord(binaryDictionary, "a-b-c", 25);
+            assertEquals(25, binaryDictionary.getMaxFrequencyOfExactMatches("abc"));
+            addUnigramWord(binaryDictionary, "ab-'-'-'-c", 30);
+            assertEquals(30, binaryDictionary.getMaxFrequencyOfExactMatches("abc"));
+            addUnigramWord(binaryDictionary, "ab c", 255);
+            assertEquals(30, binaryDictionary.getMaxFrequencyOfExactMatches("abc"));
+        } finally {
+            // Release the dictionary and delete its backing file even when an assertion fails.
+            binaryDictionary.close();
+            dictFile.delete();
+        }
+    }
 }