Implement LanguageModelDictContent.getWordProbability().

Bug: 14425059
Change-Id: I290a05cee6f341caa25fb222892505529cef1eb7
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.cpp
index 88982e5..df3daa8 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_policy.cpp
@@ -354,7 +354,7 @@
     }
     bool addedNewBigram = false;
     const int prevWordPtNodePos = getTerminalPtNodePosFromWordId(prevWordIds[0]);
-    if (mUpdatingHelper.addNgramEntry(PtNodePosArrayView::fromObject(&prevWordPtNodePos),
+    if (mUpdatingHelper.addNgramEntry(PtNodePosArrayView::singleElementView(&prevWordPtNodePos),
             wordPos, bigramProperty, &addedNewBigram)) {
         if (addedNewBigram) {
             mBigramCount++;
@@ -396,7 +396,7 @@
     }
     const int prevWordPtNodePos = getTerminalPtNodePosFromWordId(prevWordIds[0]);
     if (mUpdatingHelper.removeNgramEntry(
-            PtNodePosArrayView::fromObject(&prevWordPtNodePos), wordPos)) {
+            PtNodePosArrayView::singleElementView(&prevWordPtNodePos), wordPos)) {
         mBigramCount--;
         return true;
     } else {
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp
index d5749e9..f54bb15 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp
@@ -38,6 +38,40 @@
             0 /* nextLevelBitmapEntryIndex */, outNgramCount);
 }
 
+int LanguageModelDictContent::getWordProbability(const WordIdArrayView prevWordIds,
+        const int wordId) const {
+    int bitmapEntryIndices[MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1];
+    bitmapEntryIndices[0] = mTrieMap.getRootBitmapEntryIndex();
+    int maxLevel = 0;
+    for (size_t i = 0; i < prevWordIds.size(); ++i) {
+        const int nextBitmapEntryIndex =
+                mTrieMap.get(prevWordIds[i], bitmapEntryIndices[i]).mNextLevelBitmapEntryIndex;
+        if (nextBitmapEntryIndex == TrieMap::INVALID_INDEX) {
+            break;
+        }
+        maxLevel = i + 1;
+        bitmapEntryIndices[i + 1] = nextBitmapEntryIndex;
+    }
+
+    for (int i = maxLevel; i >= 0; --i) {
+        const TrieMap::Result result = mTrieMap.get(wordId, bitmapEntryIndices[i]);
+        if (!result.mIsValid) {
+            continue;
+        }
+        const int probability =
+                ProbabilityEntry::decode(result.mValue, mHasHistoricalInfo).getProbability();
+        if (mHasHistoricalInfo) {
+            return std::min(
+                    probability + ForgettingCurveUtils::getProbabilityBiasForNgram(i + 1 /* n */),
+                    MAX_PROBABILITY);
+        } else {
+            return probability;
+        }
+    }
+    // Cannot find the word.
+    return NOT_A_PROBABILITY;
+}
+
 ProbabilityEntry LanguageModelDictContent::getNgramProbabilityEntry(
         const WordIdArrayView prevWordIds, const int wordId) const {
     const int bitmapEntryIndex = getBitmapEntryIndex(prevWordIds);
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h
index aa612e3..4e0b470 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h
@@ -128,6 +128,8 @@
             const LanguageModelDictContent *const originalContent,
             int *const outNgramCount);
 
+    int getWordProbability(const WordIdArrayView prevWordIds, const int wordId) const;
+
     ProbabilityEntry getProbabilityEntry(const int wordId) const {
         return getNgramProbabilityEntry(WordIdArrayView(), wordId);
     }
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp
index 6de3e5a..308c355 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp
@@ -115,24 +115,12 @@
 
 int Ver4PatriciaTriePolicy::getProbabilityOfWordInContext(const int *const prevWordIds,
         const int wordId, MultiBigramMap *const multiBigramMap) const {
-    // TODO: Quit using MultiBigramMap.
     if (wordId == NOT_A_WORD_ID) {
         return NOT_A_PROBABILITY;
     }
-    const int ptNodePos =
-            mBuffers->getTerminalPositionLookupTable()->getTerminalPtNodePosition(wordId);
-    const PtNodeParams ptNodeParams(mNodeReader.fetchPtNodeParamsInBufferFromPtNodePos(ptNodePos));
-    if (multiBigramMap) {
-        return multiBigramMap->getBigramProbability(this /* structurePolicy */, prevWordIds,
-                wordId, ptNodeParams.getProbability());
-    }
-    if (prevWordIds) {
-        const int probability = getProbabilityOfWord(prevWordIds, wordId);
-        if (probability != NOT_A_PROBABILITY) {
-            return probability;
-        }
-    }
-    return getProbability(ptNodeParams.getProbability(), NOT_A_PROBABILITY);
+    // TODO: Support n-gram.
+    return mBuffers->getLanguageModelDictContent()->getWordProbability(
+            WordIdArrayView::singleElementView(prevWordIds), wordId);
 }
 
 int Ver4PatriciaTriePolicy::getProbability(const int unigramProbability,
@@ -166,7 +154,7 @@
         // TODO: Support n-gram.
         const ProbabilityEntry probabilityEntry =
                 mBuffers->getLanguageModelDictContent()->getNgramProbabilityEntry(
-                        IntArrayView::fromObject(prevWordIds), wordId);
+                        IntArrayView::singleElementView(prevWordIds), wordId);
         if (!probabilityEntry.isValid()) {
             return NOT_A_PROBABILITY;
         }
@@ -194,7 +182,7 @@
     // TODO: Support n-gram.
     const auto languageModelDictContent = mBuffers->getLanguageModelDictContent();
     for (const auto entry : languageModelDictContent->getProbabilityEntries(
-            WordIdArrayView::fromObject(prevWordIds))) {
+            WordIdArrayView::singleElementView(prevWordIds))) {
         const ProbabilityEntry &probabilityEntry = entry.getProbabilityEntry();
         const int probability = probabilityEntry.hasHistoricalInfo() ?
                 ForgettingCurveUtils::decodeProbability(
@@ -511,7 +499,7 @@
     // Fetch bigram information.
     // TODO: Support n-gram.
     std::vector<BigramProperty> bigrams;
-    const WordIdArrayView prevWordIds = WordIdArrayView::fromObject(&wordId);
+    const WordIdArrayView prevWordIds = WordIdArrayView::singleElementView(&wordId);
     int bigramWord1CodePoints[MAX_WORD_LENGTH];
     for (const auto entry : mBuffers->getLanguageModelDictContent()->getProbabilityEntries(
             prevWordIds)) {
diff --git a/native/jni/src/suggest/policyimpl/dictionary/utils/forgetting_curve_utils.h b/native/jni/src/suggest/policyimpl/dictionary/utils/forgetting_curve_utils.h
index 9910777..313eb6b 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/utils/forgetting_curve_utils.h
+++ b/native/jni/src/suggest/policyimpl/dictionary/utils/forgetting_curve_utils.h
@@ -48,6 +48,11 @@
     static bool needsToDecay(const bool mindsBlockByDecay, const int unigramCount,
             const int bigramCount, const HeaderPolicy *const headerPolicy);
 
+    // TODO: Improve probability computation method and remove this.
+    static int getProbabilityBiasForNgram(const int n) {
+        return (n - 1) * MULTIPLIER_TWO_IN_PROBABILITY_SCALE;
+    }
+
     AK_FORCE_INLINE static int getUnigramCountHardLimit(const int maxUnigramCount) {
         return static_cast<int>(static_cast<float>(maxUnigramCount)
                 * UNIGRAM_COUNT_HARD_LIMIT_WEIGHT);
diff --git a/native/jni/src/utils/int_array_view.h b/native/jni/src/utils/int_array_view.h
index c9c3b21..08256bd 100644
--- a/native/jni/src/utils/int_array_view.h
+++ b/native/jni/src/utils/int_array_view.h
@@ -61,9 +61,9 @@
         return IntArrayView(array, N);
     }
 
-    // Returns a view that points one int object. Does not take ownership of the given object.
-    AK_FORCE_INLINE static IntArrayView fromObject(const int *const object) {
-        return IntArrayView(object, 1);
+    // Returns a view that points one int object.
+    AK_FORCE_INLINE static IntArrayView singleElementView(const int *const ptr) {
+        return IntArrayView(ptr, 1);
     }
 
     AK_FORCE_INLINE int operator[](const size_t index) const {
diff --git a/native/jni/tests/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content_test.cpp b/native/jni/tests/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content_test.cpp
index ca8d56f..e6f0353 100644
--- a/native/jni/tests/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content_test.cpp
+++ b/native/jni/tests/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content_test.cpp
@@ -26,28 +26,28 @@
 namespace {
 
 TEST(LanguageModelDictContentTest, TestUnigramProbability) {
-    LanguageModelDictContent LanguageModelDictContent(false /* useHistoricalInfo */);
+    LanguageModelDictContent languageModelDictContent(false /* useHistoricalInfo */);
 
     const int flag = 0xFF;
     const int probability = 10;
     const int wordId = 100;
     const ProbabilityEntry probabilityEntry(flag, probability);
-    LanguageModelDictContent.setProbabilityEntry(wordId, &probabilityEntry);
+    languageModelDictContent.setProbabilityEntry(wordId, &probabilityEntry);
     const ProbabilityEntry entry =
-            LanguageModelDictContent.getProbabilityEntry(wordId);
+            languageModelDictContent.getProbabilityEntry(wordId);
     EXPECT_EQ(flag, entry.getFlags());
     EXPECT_EQ(probability, entry.getProbability());
 
     // Remove
-    EXPECT_TRUE(LanguageModelDictContent.removeProbabilityEntry(wordId));
-    EXPECT_FALSE(LanguageModelDictContent.getProbabilityEntry(wordId).isValid());
-    EXPECT_FALSE(LanguageModelDictContent.removeProbabilityEntry(wordId));
-    EXPECT_TRUE(LanguageModelDictContent.setProbabilityEntry(wordId, &probabilityEntry));
-    EXPECT_TRUE(LanguageModelDictContent.getProbabilityEntry(wordId).isValid());
+    EXPECT_TRUE(languageModelDictContent.removeProbabilityEntry(wordId));
+    EXPECT_FALSE(languageModelDictContent.getProbabilityEntry(wordId).isValid());
+    EXPECT_FALSE(languageModelDictContent.removeProbabilityEntry(wordId));
+    EXPECT_TRUE(languageModelDictContent.setProbabilityEntry(wordId, &probabilityEntry));
+    EXPECT_TRUE(languageModelDictContent.getProbabilityEntry(wordId).isValid());
 }
 
 TEST(LanguageModelDictContentTest, TestUnigramProbabilityWithHistoricalInfo) {
-    LanguageModelDictContent LanguageModelDictContent(true /* useHistoricalInfo */);
+    LanguageModelDictContent languageModelDictContent(true /* useHistoricalInfo */);
 
     const int flag = 0xF0;
     const int timestamp = 0x3FFFFFFF;
@@ -56,19 +56,19 @@
     const int wordId = 100;
     const HistoricalInfo historicalInfo(timestamp, level, count);
     const ProbabilityEntry probabilityEntry(flag, &historicalInfo);
-    LanguageModelDictContent.setProbabilityEntry(wordId, &probabilityEntry);
-    const ProbabilityEntry entry = LanguageModelDictContent.getProbabilityEntry(wordId);
+    languageModelDictContent.setProbabilityEntry(wordId, &probabilityEntry);
+    const ProbabilityEntry entry = languageModelDictContent.getProbabilityEntry(wordId);
     EXPECT_EQ(flag, entry.getFlags());
     EXPECT_EQ(timestamp, entry.getHistoricalInfo()->getTimeStamp());
     EXPECT_EQ(level, entry.getHistoricalInfo()->getLevel());
     EXPECT_EQ(count, entry.getHistoricalInfo()->getCount());
 
     // Remove
-    EXPECT_TRUE(LanguageModelDictContent.removeProbabilityEntry(wordId));
-    EXPECT_FALSE(LanguageModelDictContent.getProbabilityEntry(wordId).isValid());
-    EXPECT_FALSE(LanguageModelDictContent.removeProbabilityEntry(wordId));
-    EXPECT_TRUE(LanguageModelDictContent.setProbabilityEntry(wordId, &probabilityEntry));
-    EXPECT_TRUE(LanguageModelDictContent.removeProbabilityEntry(wordId));
+    EXPECT_TRUE(languageModelDictContent.removeProbabilityEntry(wordId));
+    EXPECT_FALSE(languageModelDictContent.getProbabilityEntry(wordId).isValid());
+    EXPECT_FALSE(languageModelDictContent.removeProbabilityEntry(wordId));
+    EXPECT_TRUE(languageModelDictContent.setProbabilityEntry(wordId, &probabilityEntry));
+    EXPECT_TRUE(languageModelDictContent.removeProbabilityEntry(wordId));
 }
 
 TEST(LanguageModelDictContentTest, TestIterateProbabilityEntry) {
@@ -89,5 +89,31 @@
     EXPECT_TRUE(wordIdSet.empty());
 }
 
+TEST(LanguageModelDictContentTest, TestGetWordProbability) {
+    LanguageModelDictContent languageModelDictContent(false /* useHistoricalInfo */);
+
+    const int flag = 0xFF;
+    const int probability = 10;
+    const int bigramProbability = 20;
+    const int trigramProbability = 30;
+    const int wordId = 100;
+    const int prevWordIdArray[] = { 1, 2 };
+    const WordIdArrayView prevWordIds = WordIdArrayView::fromFixedSizeArray(prevWordIdArray);
+
+    const ProbabilityEntry probabilityEntry(flag, probability);
+    languageModelDictContent.setProbabilityEntry(wordId, &probabilityEntry);
+    const ProbabilityEntry bigramProbabilityEntry(flag, bigramProbability);
+    languageModelDictContent.setProbabilityEntry(prevWordIds[0], &probabilityEntry);
+    languageModelDictContent.setNgramProbabilityEntry(prevWordIds.limit(1), wordId,
+            &bigramProbabilityEntry);
+    EXPECT_EQ(bigramProbability, languageModelDictContent.getWordProbability(prevWordIds, wordId));
+    const ProbabilityEntry trigramProbabilityEntry(flag, trigramProbability);
+    languageModelDictContent.setNgramProbabilityEntry(prevWordIds.limit(1),
+            prevWordIds[1], &probabilityEntry);
+    languageModelDictContent.setNgramProbabilityEntry(prevWordIds.limit(2), wordId,
+            &trigramProbabilityEntry);
+    EXPECT_EQ(trigramProbability, languageModelDictContent.getWordProbability(prevWordIds, wordId));
+}
+
 }  // namespace
 }  // namespace latinime
diff --git a/native/jni/tests/utils/int_array_view_test.cpp b/native/jni/tests/utils/int_array_view_test.cpp
index 161df2f..93bad58 100644
--- a/native/jni/tests/utils/int_array_view_test.cpp
+++ b/native/jni/tests/utils/int_array_view_test.cpp
@@ -52,7 +52,7 @@
 
 TEST(IntArrayViewTest, TestConstructFromObject) {
     const int object = 10;
-    const auto intArrayView = IntArrayView::fromObject(&object);
+    const auto intArrayView = IntArrayView::singleElementView(&object);
     EXPECT_EQ(1u, intArrayView.size());
     EXPECT_EQ(object, intArrayView[0]);
 }