Use MultiBigramMap in structure policy.

Bug: 14425059
Change-Id: I4d78da4839ef177e0223e6e5bcf0ebd7315c3099
diff --git a/native/jni/src/suggest/core/dicnode/dic_node_utils.cpp b/native/jni/src/suggest/core/dicnode/dic_node_utils.cpp
index 9f03e30..19f92cc 100644
--- a/native/jni/src/suggest/core/dicnode/dic_node_utils.cpp
+++ b/native/jni/src/suggest/core/dicnode/dic_node_utils.cpp
@@ -18,7 +18,6 @@
 
 #include "suggest/core/dicnode/dic_node.h"
 #include "suggest/core/dicnode/dic_node_vector.h"
-#include "suggest/core/dictionary/multi_bigram_map.h"
 #include "suggest/core/policy/dictionary_structure_with_buffer_policy.h"
 
 namespace latinime {
@@ -73,25 +72,12 @@
     if (dicNode->hasMultipleWords() && !dicNode->isValidMultipleWordSuggestion()) {
         return static_cast<float>(MAX_VALUE_FOR_WEIGHTING);
     }
-    const int probability = getBigramNodeProbability(dictionaryStructurePolicy, dicNode,
-            multiBigramMap);
+    const int probability = dictionaryStructurePolicy->getProbabilityOfWordInContext(
+            dicNode->getPrevWordIds(), dicNode->getWordId(), multiBigramMap);
     // TODO: This equation to calculate the improbability looks unreasonable.  Investigate this.
     const float cost = static_cast<float>(MAX_PROBABILITY - probability)
             / static_cast<float>(MAX_PROBABILITY);
     return cost;
 }
 
-/* static */ int DicNodeUtils::getBigramNodeProbability(
-        const DictionaryStructureWithBufferPolicy *const dictionaryStructurePolicy,
-        const DicNode *const dicNode, MultiBigramMap *const multiBigramMap) {
-    const int unigramProbability = dicNode->getUnigramProbability();
-    if (multiBigramMap) {
-        const int *const prevWordIds = dicNode->getPrevWordIds();
-        return multiBigramMap->getBigramProbability(dictionaryStructurePolicy,
-                prevWordIds, dicNode->getWordId(), unigramProbability);
-    }
-    return dictionaryStructurePolicy->getProbability(unigramProbability,
-            NOT_A_PROBABILITY);
-}
-
 } // namespace latinime
diff --git a/native/jni/src/suggest/core/dicnode/dic_node_utils.h b/native/jni/src/suggest/core/dicnode/dic_node_utils.h
index 56ff6e3..961a1c2 100644
--- a/native/jni/src/suggest/core/dicnode/dic_node_utils.h
+++ b/native/jni/src/suggest/core/dicnode/dic_node_utils.h
@@ -46,10 +46,6 @@
     DISALLOW_IMPLICIT_CONSTRUCTORS(DicNodeUtils);
     // Max number of bigrams to look up
     static const int MAX_BIGRAMS_CONSIDERED_PER_CONTEXT = 500;
-
-    static int getBigramNodeProbability(
-            const DictionaryStructureWithBufferPolicy *const dictionaryStructurePolicy,
-            const DicNode *const dicNode, MultiBigramMap *const multiBigramMap);
 };
 } // namespace latinime
 #endif // LATINIME_DIC_NODE_UTILS_H
diff --git a/native/jni/src/suggest/core/policy/dictionary_structure_with_buffer_policy.h b/native/jni/src/suggest/core/policy/dictionary_structure_with_buffer_policy.h
index aeeb66f..4e55418 100644
--- a/native/jni/src/suggest/core/policy/dictionary_structure_with_buffer_policy.h
+++ b/native/jni/src/suggest/core/policy/dictionary_structure_with_buffer_policy.h
@@ -29,6 +29,7 @@
 class DicNode;
 class DicNodeVector;
 class DictionaryHeaderStructurePolicy;
+class MultiBigramMap;
 class NgramListener;
 class PrevWordsInfo;
 class UnigramProperty;
@@ -56,6 +57,10 @@
     virtual int getWordId(const CodePointArrayView wordCodePoints,
             const bool forceLowerCaseSearch) const = 0;
 
+    virtual int getProbabilityOfWordInContext(const int *const prevWordIds, const int wordId,
+            MultiBigramMap *const multiBigramMap) const = 0;
+
+    // TODO: Remove
     virtual int getProbability(const int unigramProbability, const int bigramProbability) const = 0;
 
     virtual int getProbabilityOfWord(const int *const prevWordIds, const int wordId) const = 0;
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.cpp
index 6480374..88982e5 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.cpp
@@ -28,6 +28,7 @@
 
 #include "suggest/core/dicnode/dic_node.h"
 #include "suggest/core/dicnode/dic_node_vector.h"
+#include "suggest/core/dictionary/multi_bigram_map.h"
 #include "suggest/core/dictionary/ngram_listener.h"
 #include "suggest/core/dictionary/property/bigram_property.h"
 #include "suggest/core/dictionary/property/unigram_property.h"
@@ -117,6 +118,26 @@
     return getWordIdFromTerminalPtNodePos(ptNodePos);
 }
 
+int Ver4PatriciaTriePolicy::getProbabilityOfWordInContext(const int *const prevWordIds,
+        const int wordId, MultiBigramMap *const multiBigramMap) const {
+    if (wordId == NOT_A_WORD_ID) {
+        return NOT_A_PROBABILITY;
+    }
+    const int ptNodePos = getTerminalPtNodePosFromWordId(wordId);
+    const PtNodeParams ptNodeParams(mNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(ptNodePos));
+    if (multiBigramMap) {
+        return multiBigramMap->getBigramProbability(this /* structurePolicy */, prevWordIds,
+                wordId, ptNodeParams.getProbability());
+    }
+    if (prevWordIds) {
+        const int probability = getProbabilityOfWord(prevWordIds, wordId);
+        if (probability != NOT_A_PROBABILITY) {
+            return probability;
+        }
+    }
+    return getProbability(ptNodeParams.getProbability(), NOT_A_PROBABILITY);
+}
+
 int Ver4PatriciaTriePolicy::getProbability(const int unigramProbability,
         const int bigramProbability) const {
     if (mHeaderPolicy->isDecayingDict()) {
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.h b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.h
index 562c219..06d7041 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.h
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.h
@@ -91,6 +91,9 @@
 
     int getWordId(const CodePointArrayView wordCodePoints, const bool forceLowerCaseSearch) const;
 
+    int getProbabilityOfWordInContext(const int *const prevWordIds, const int wordId,
+            MultiBigramMap *const multiBigramMap) const;
+
     int getProbability(const int unigramProbability, const int bigramProbability) const;
 
     int getProbabilityOfWord(const int *const prevWordIds, const int wordId) const;
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.cpp
index e0406ab..80bbf47 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.cpp
@@ -21,6 +21,7 @@
 #include "suggest/core/dicnode/dic_node.h"
 #include "suggest/core/dicnode/dic_node_vector.h"
 #include "suggest/core/dictionary/binary_dictionary_bigrams_iterator.h"
+#include "suggest/core/dictionary/multi_bigram_map.h"
 #include "suggest/core/dictionary/ngram_listener.h"
 #include "suggest/core/session/prev_words_info.h"
 #include "suggest/policyimpl/dictionary/structure/pt_common/dynamic_pt_reading_helper.h"
@@ -281,6 +282,27 @@
     return getWordIdFromTerminalPtNodePos(ptNodePos);
 }
 
+int PatriciaTriePolicy::getProbabilityOfWordInContext(const int *const prevWordIds,
+        const int wordId, MultiBigramMap *const multiBigramMap) const {
+    if (wordId == NOT_A_WORD_ID) {
+        return NOT_A_PROBABILITY;
+    }
+    const int ptNodePos = getTerminalPtNodePosFromWordId(wordId);
+    const PtNodeParams ptNodeParams =
+            mPtNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(ptNodePos);
+    if (multiBigramMap) {
+        return multiBigramMap->getBigramProbability(this /* structurePolicy */, prevWordIds,
+                wordId, ptNodeParams.getProbability());
+    }
+    if (prevWordIds) {
+        const int bigramProbability = getProbabilityOfWord(prevWordIds, wordId);
+        if (bigramProbability != NOT_A_PROBABILITY) {
+            return bigramProbability;
+        }
+    }
+    return getProbability(ptNodeParams.getProbability(), NOT_A_PROBABILITY);
+}
+
 int PatriciaTriePolicy::getProbability(const int unigramProbability,
         const int bigramProbability) const {
     // Due to space constraints, the probability for bigrams is approximate - the lower the unigram
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.h
index 66df527..a2d6b6f 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.h
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v2/patricia_trie_policy.h
@@ -66,6 +66,9 @@
 
     int getWordId(const CodePointArrayView wordCodePoints, const bool forceLowerCaseSearch) const;
 
+    int getProbabilityOfWordInContext(const int *const prevWordIds, const int wordId,
+            MultiBigramMap *const multiBigramMap) const;
+
     int getProbability(const int unigramProbability, const int bigramProbability) const;
 
     int getProbabilityOfWord(const int *const prevWordIds, const int wordId) const;
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp
index 466c499..6de3e5a 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp
@@ -20,6 +20,7 @@
 
 #include "suggest/core/dicnode/dic_node.h"
 #include "suggest/core/dicnode/dic_node_vector.h"
+#include "suggest/core/dictionary/multi_bigram_map.h"
 #include "suggest/core/dictionary/ngram_listener.h"
 #include "suggest/core/dictionary/property/bigram_property.h"
 #include "suggest/core/dictionary/property/unigram_property.h"
@@ -112,6 +113,28 @@
     return ptNodeParams.getTerminalId();
 }
 
+int Ver4PatriciaTriePolicy::getProbabilityOfWordInContext(const int *const prevWordIds,
+        const int wordId, MultiBigramMap *const multiBigramMap) const {
+    // TODO: Quit using MultiBigramMap.
+    if (wordId == NOT_A_WORD_ID) {
+        return NOT_A_PROBABILITY;
+    }
+    const int ptNodePos =
+            mBuffers->getTerminalPositionLookupTable()->getTerminalPtNodePosition(wordId);
+    const PtNodeParams ptNodeParams(mNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(ptNodePos));
+    if (multiBigramMap) {
+        return multiBigramMap->getBigramProbability(this /* structurePolicy */, prevWordIds,
+                wordId, ptNodeParams.getProbability());
+    }
+    if (prevWordIds) {
+        const int probability = getProbabilityOfWord(prevWordIds, wordId);
+        if (probability != NOT_A_PROBABILITY) {
+            return probability;
+        }
+    }
+    return getProbability(ptNodeParams.getProbability(), NOT_A_PROBABILITY);
+}
+
 int Ver4PatriciaTriePolicy::getProbability(const int unigramProbability,
         const int bigramProbability) const {
     if (mHeaderPolicy->isDecayingDict()) {
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.h
index 0b8eec4..c9df9df 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.h
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.h
@@ -68,6 +68,9 @@
 
     int getWordId(const CodePointArrayView wordCodePoints, const bool forceLowerCaseSearch) const;
 
+    int getProbabilityOfWordInContext(const int *const prevWordIds, const int wordId,
+            MultiBigramMap *const multiBigramMap) const;
+
     int getProbability(const int unigramProbability, const int bigramProbability) const;
 
     int getProbabilityOfWord(const int *const prevWordIds, const int wordId) const;