Update probabilities in language model dict content for GC.

Bug: 14425059
Change-Id: I354408afd8e5c1955ff0acea3d0243d628fe3843
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp
index bbcea2e..a66cfef 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp
@@ -16,6 +16,8 @@
 
 #include "suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h"
 
+#include "suggest/policyimpl/dictionary/utils/forgetting_curve_utils.h"
+
 namespace latinime {
 
 bool LanguageModelDictContent::save(FILE *const file) const {
@@ -118,4 +120,46 @@
     return bitmapEntryIndex;
 }
 
+bool LanguageModelDictContent::updateAllProbabilityEntriesInner(const int bitmapEntryIndex,
+        const int level, const HeaderPolicy *const headerPolicy, int *const outEntryCounts) {
+    for (const auto &entry : mTrieMap.getEntriesInSpecifiedLevel(bitmapEntryIndex)) {
+        if (level > MAX_PREV_WORD_COUNT_FOR_N_GRAM) {
+            AKLOGE("Invalid level. level: %d, MAX_PREV_WORD_COUNT_FOR_N_GRAM: %d.",
+                    level, MAX_PREV_WORD_COUNT_FOR_N_GRAM);
+            return false;
+        }
+        const ProbabilityEntry probabilityEntry =
+                ProbabilityEntry::decode(entry.value(), mHasHistoricalInfo);
+        if (mHasHistoricalInfo && !probabilityEntry.representsBeginningOfSentence()) {
+            const HistoricalInfo historicalInfo = ForgettingCurveUtils::createHistoricalInfoToSave(
+                    probabilityEntry.getHistoricalInfo(), headerPolicy);
+            if (ForgettingCurveUtils::needsToKeep(&historicalInfo, headerPolicy)) {
+                // Update the entry.
+                const ProbabilityEntry updatedEntry(probabilityEntry.getFlags(), &historicalInfo);
+                if (!mTrieMap.put(entry.key(), updatedEntry.encode(mHasHistoricalInfo),
+                        bitmapEntryIndex)) {
+                    return false;
+                }
+            } else {
+                // Remove the entry.
+                if (!mTrieMap.remove(entry.key(), bitmapEntryIndex)) {
+                    return false;
+                }
+                continue;
+            }
+        }
+        if (!probabilityEntry.representsBeginningOfSentence()) {
+            outEntryCounts[level] += 1;
+        }
+        if (!entry.hasNextLevelMap()) {
+            continue;
+        }
+        if (!updateAllProbabilityEntriesInner(entry.getNextLevelBitmapEntryIndex(), level + 1,
+                headerPolicy, outEntryCounts)) {
+            return false;
+        }
+    }
+    return true;
+}
+
 } // namespace latinime
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h
index bd07f2f..31ee2fe 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h
@@ -29,6 +29,8 @@
 
 namespace latinime {
 
+class HeaderPolicy;
+
 /**
  * Class representing language model.
  *
@@ -73,6 +75,12 @@
 
     bool removeNgramProbabilityEntry(const WordIdArrayView prevWordIds, const int wordId);
 
+    bool updateAllProbabilityEntries(const HeaderPolicy *const headerPolicy,
+            int *const outEntryCounts) {
+        return updateAllProbabilityEntriesInner(mTrieMap.getRootBitmapEntryIndex(), 0 /* level */,
+                headerPolicy, outEntryCounts);
+    }
+
  private:
     DISALLOW_COPY_AND_ASSIGN(LanguageModelDictContent);
 
@@ -84,6 +92,8 @@
             int *const outNgramCount);
     int createAndGetBitmapEntryIndex(const WordIdArrayView prevWordIds);
     int getBitmapEntryIndex(const WordIdArrayView prevWordIds) const;
+    bool updateAllProbabilityEntriesInner(const int bitmapEntryIndex, const int level,
+            const HeaderPolicy *const headerPolicy, int *const outEntryCounts);
 };
 } // namespace latinime
 #endif /* LATINIME_LANGUAGE_MODEL_DICT_CONTENT_H */
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.cpp
index fb6840b..b7c31bf 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.cpp
@@ -161,29 +161,15 @@
     const ProbabilityEntry originalProbabilityEntry =
             mBuffers->getLanguageModelDictContent()->getProbabilityEntry(
                     toBeUpdatedPtNodeParams->getTerminalId());
-    if (originalProbabilityEntry.hasHistoricalInfo()) {
-        const HistoricalInfo historicalInfo = ForgettingCurveUtils::createHistoricalInfoToSave(
-                originalProbabilityEntry.getHistoricalInfo(), mHeaderPolicy);
-        const ProbabilityEntry probabilityEntry(originalProbabilityEntry.getFlags(),
-                &historicalInfo);
-        if (!mBuffers->getMutableLanguageModelDictContent()->setProbabilityEntry(
-                toBeUpdatedPtNodeParams->getTerminalId(), &probabilityEntry)) {
-            AKLOGE("Cannot write updated probability entry. terminalId: %d",
-                    toBeUpdatedPtNodeParams->getTerminalId());
-            return false;
-        }
-        const bool isValid = ForgettingCurveUtils::needsToKeep(&historicalInfo, mHeaderPolicy);
-        if (!isValid) {
-            if (!markPtNodeAsWillBecomeNonTerminal(toBeUpdatedPtNodeParams)) {
-                AKLOGE("Cannot mark PtNode as willBecomeNonTerminal.");
-                return false;
-            }
-        }
-        *outNeedsToKeepPtNode = isValid;
-    } else {
-        // No need to update probability.
+    if (originalProbabilityEntry.isValid()) {
         *outNeedsToKeepPtNode = true;
+        return true;
     }
+    if (!markPtNodeAsWillBecomeNonTerminal(toBeUpdatedPtNodeParams)) {
+        AKLOGE("Cannot mark PtNode as willBecomeNonTerminal.");
+        return false;
+    }
+    *outNeedsToKeepPtNode = false;
     return true;
 }
 
@@ -380,6 +366,7 @@
             isTerminal, ptNodeParams->getCodePointCount() > 1 /* hasMultipleChars */);
 }
 
+// TODO: Move probability handling code to LanguageModelDictContent.
 const ProbabilityEntry Ver4PatriciaTrieNodeWriter::createUpdatedEntryFrom(
         const ProbabilityEntry *const originalProbabilityEntry,
         const ProbabilityEntry *const probabilityEntry) const {
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.cpp
index 4220312..35bc44b 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.cpp
@@ -85,6 +85,12 @@
             mBuffers, headerPolicy, &ptNodeReader, &ptNodeArrayReader, &bigramPolicy,
             &shortcutPolicy);
 
+    int entryCountTable[MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1];
+    if (!mBuffers->getMutableLanguageModelDictContent()->updateAllProbabilityEntries(headerPolicy,
+            entryCountTable)) {
+        AKLOGE("Failed to update probabilities in language model dict content.");
+        return false;
+    }
     DynamicPtReadingHelper readingHelper(&ptNodeReader, &ptNodeArrayReader);
     readingHelper.initWithPtNodeArrayPos(rootPtNodeArrayPos);
     DynamicPtGcEventListeners
diff --git a/native/jni/src/suggest/policyimpl/dictionary/utils/trie_map.h b/native/jni/src/suggest/policyimpl/dictionary/utils/trie_map.h
index 6d91790..c2aeac2 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/utils/trie_map.h
+++ b/native/jni/src/suggest/policyimpl/dictionary/utils/trie_map.h
@@ -84,6 +84,10 @@
                 return mValue;
             }
 
+            AK_FORCE_INLINE int getNextLevelBitmapEntryIndex() const {
+                return mNextLevelBitmapEntryIndex;
+            }
+
          private:
             const TrieMap *const mTrieMap;
             const int mKey;