Truncate entries in language model dict content.

Bug: 14425059

Change-Id: I023c1d5109a2c43fcea3bb11a0fd7198c82891ba
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp
index a66cfef..ea2d24e 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp
@@ -16,6 +16,9 @@
 
 #include "suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h"
 
+#include <algorithm>
+#include <cstring>
+
 #include "suggest/policyimpl/dictionary/utils/forgetting_curve_utils.h"
 
 namespace latinime {
@@ -68,6 +71,19 @@
     return mTrieMap.remove(wordId, bitmapEntryIndex);
 }
 
+bool LanguageModelDictContent::truncateEntries(const int *const entryCounts,
+        const int *const maxEntryCounts, const HeaderPolicy *const headerPolicy) {
+    for (int i = 0; i <= MAX_PREV_WORD_COUNT_FOR_N_GRAM; ++i) {
+        if (entryCounts[i] <= maxEntryCounts[i]) {
+            continue;
+        }
+        if (!turncateEntriesInSpecifiedLevel(headerPolicy, maxEntryCounts[i], i)) {
+            return false;
+        }
+    }
+    return true;
+}
+
 bool LanguageModelDictContent::runGCInner(
         const TerminalPositionLookupTable::TerminalIdMap *const terminalIdMap,
         const TrieMap::TrieMapRange trieMapRange,
@@ -162,4 +178,87 @@
     return true;
 }
 
+bool LanguageModelDictContent::turncateEntriesInSpecifiedLevel(
+        const HeaderPolicy *const headerPolicy, const int maxEntryCount, const int targetLevel) {
+    std::vector<int> prevWordIds;
+    std::vector<EntryInfoToTurncate> entryInfoVector;
+    if (!getEntryInfo(headerPolicy, targetLevel, mTrieMap.getRootBitmapEntryIndex(),
+            &prevWordIds, &entryInfoVector)) {
+        return false;
+    }
+    if (static_cast<int>(entryInfoVector.size()) <= maxEntryCount) {
+        return true;
+    }
+    const int entryCountToRemove = static_cast<int>(entryInfoVector.size()) - maxEntryCount;
+    std::partial_sort(entryInfoVector.begin(), entryInfoVector.begin() + entryCountToRemove,
+            entryInfoVector.end(),
+            EntryInfoToTurncate::Comparator());
+    for (int i = 0; i < entryCountToRemove; ++i) {
+        const EntryInfoToTurncate &entryInfo = entryInfoVector[i];
+        if (!removeNgramProbabilityEntry(
+                WordIdArrayView(entryInfo.mPrevWordIds, entryInfo.mEntryLevel), entryInfo.mKey)) {
+            return false;
+        }
+    }
+    return true;
+}
+
+bool LanguageModelDictContent::getEntryInfo(const HeaderPolicy *const headerPolicy,
+        const int targetLevel, const int bitmapEntryIndex,  std::vector<int> *const prevWordIds,
+        std::vector<EntryInfoToTurncate> *const outEntryInfo) const {
+    const int currentLevel = prevWordIds->size();
+    for (const auto &entry : mTrieMap.getEntriesInSpecifiedLevel(bitmapEntryIndex)) {
+        if (currentLevel < targetLevel) {
+            if (!entry.hasNextLevelMap()) {
+                continue;
+            }
+            prevWordIds->push_back(entry.key());
+            if (!getEntryInfo(headerPolicy, targetLevel, entry.getNextLevelBitmapEntryIndex(),
+                    prevWordIds, outEntryInfo)) {
+                return false;
+            }
+            prevWordIds->pop_back();
+            continue;
+        }
+        const ProbabilityEntry probabilityEntry =
+                ProbabilityEntry::decode(entry.value(), mHasHistoricalInfo);
+        const int probability = (mHasHistoricalInfo) ?
+                ForgettingCurveUtils::decodeProbability(probabilityEntry.getHistoricalInfo(),
+                        headerPolicy) : probabilityEntry.getProbability();
+        outEntryInfo->emplace_back(probability,
+                probabilityEntry.getHistoricalInfo()->getTimeStamp(),
+                entry.key(), targetLevel, prevWordIds->data());
+    }
+    return true;
+}
+
+bool LanguageModelDictContent::EntryInfoToTurncate::Comparator::operator()(
+        const EntryInfoToTurncate &left, const EntryInfoToTurncate &right) const {
+    if (left.mProbability != right.mProbability) {
+        return left.mProbability < right.mProbability;
+    }
+    if (left.mTimestamp != right.mTimestamp) {
+        return left.mTimestamp > right.mTimestamp;
+    }
+    if (left.mKey != right.mKey) {
+        return left.mKey < right.mKey;
+    }
+    if (left.mEntryLevel != right.mEntryLevel) {
+        return left.mEntryLevel > right.mEntryLevel;
+    }
+    for (int i = 0; i < left.mEntryLevel; ++i) {
+        if (left.mPrevWordIds[i] != right.mPrevWordIds[i]) {
+            return left.mPrevWordIds[i] < right.mPrevWordIds[i];
+        }
+    }
+    // left and right represent the same entry.
+    return false;
+}
+
+LanguageModelDictContent::EntryInfoToTurncate::EntryInfoToTurncate(const int probability,
+        const int timestamp, const int key, const int entryLevel, const int *const prevWordIds)
+        : mProbability(probability), mTimestamp(timestamp), mKey(key), mEntryLevel(entryLevel) {
+    memmove(mPrevWordIds, prevWordIds, mEntryLevel * sizeof(mPrevWordIds[0]));
+}
+
 } // namespace latinime
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h
index 31ee2fe..43b2aab 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h
@@ -18,6 +18,7 @@
 #define LATINIME_LANGUAGE_MODEL_DICT_CONTENT_H
 
 #include <cstdio>
+#include <vector>
 
 #include "defines.h"
 #include "suggest/policyimpl/dictionary/structure/v4/content/probability_entry.h"
@@ -77,13 +78,43 @@
 
     bool updateAllProbabilityEntries(const HeaderPolicy *const headerPolicy,
             int *const outEntryCounts) {
+        for (int i = 0; i <= MAX_PREV_WORD_COUNT_FOR_N_GRAM; ++i) {
+            outEntryCounts[i] = 0;
+        }
         return updateAllProbabilityEntriesInner(mTrieMap.getRootBitmapEntryIndex(), 0 /* level */,
                 headerPolicy, outEntryCounts);
     }
 
+    // entryCounts should be created by updateAllProbabilityEntries.
+    bool truncateEntries(const int *const entryCounts, const int *const maxEntryCounts,
+            const HeaderPolicy *const headerPolicy);
+
  private:
     DISALLOW_COPY_AND_ASSIGN(LanguageModelDictContent);
 
+    class EntryInfoToTurncate {
+     public:
+        class Comparator {
+         public:
+            bool operator()(const EntryInfoToTurncate &left,
+                    const EntryInfoToTurncate &right) const;
+         private:
+            DISALLOW_ASSIGNMENT_OPERATOR(Comparator);
+        };
+
+        EntryInfoToTurncate(const int probability, const int timestamp, const int key,
+                const int entryLevel, const int *const prevWordIds);
+
+        int mProbability;
+        int mTimestamp;
+        int mKey;
+        int mEntryLevel;
+        int mPrevWordIds[MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1];
+
+     private:
+        DISALLOW_DEFAULT_CONSTRUCTOR(EntryInfoToTurncate);
+    };
+
     TrieMap mTrieMap;
     const bool mHasHistoricalInfo;
 
@@ -94,6 +125,11 @@
     int getBitmapEntryIndex(const WordIdArrayView prevWordIds) const;
     bool updateAllProbabilityEntriesInner(const int bitmapEntryIndex, const int level,
             const HeaderPolicy *const headerPolicy, int *const outEntryCounts);
+    bool turncateEntriesInSpecifiedLevel(const HeaderPolicy *const headerPolicy,
+            const int maxEntryCount, const int targetLevel);
+    bool getEntryInfo(const HeaderPolicy *const headerPolicy, const int targetLevel,
+            const int bitmapEntryIndex, std::vector<int> *const prevWordIds,
+            std::vector<EntryInfoToTurncate> *const outEntryInfo) const;
 };
 } // namespace latinime
 #endif /* LATINIME_LANGUAGE_MODEL_DICT_CONTENT_H */
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.cpp
index 35bc44b..d53575a 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.cpp
@@ -91,6 +91,21 @@
         AKLOGE("Failed to update probabilities in language model dict content.");
         return false;
     }
+    if (headerPolicy->isDecayingDict()) {
+        int maxEntryCountTable[MAX_PREV_WORD_COUNT_FOR_N_GRAM + 1];
+        maxEntryCountTable[0] = headerPolicy->getMaxUnigramCount();
+        maxEntryCountTable[1] = headerPolicy->getMaxBigramCount();
+        for (size_t i = 2; i < NELEMS(maxEntryCountTable); ++i) {
+            // TODO: Have max n-gram count.
+            maxEntryCountTable[i] = headerPolicy->getMaxBigramCount();
+        }
+        if (!mBuffers->getMutableLanguageModelDictContent()->truncateEntries(entryCountTable,
+                maxEntryCountTable,  headerPolicy)) {
+            AKLOGE("Failed to truncate entries in language model dict content.");
+            return false;
+        }
+    }
+
     DynamicPtReadingHelper readingHelper(&ptNodeReader, &ptNodeArrayReader);
     readingHelper.initWithPtNodeArrayPos(rootPtNodeArrayPos);
     DynamicPtGcEventListeners
@@ -193,6 +208,7 @@
     return true;
 }
 
+// TODO: Remove.
 bool Ver4PatriciaTrieWritingHelper::truncateUnigrams(
         const Ver4PatriciaTrieNodeReader *const ptNodeReader,
         Ver4PatriciaTrieNodeWriter *const ptNodeWriter, const int maxUnigramCount) {
@@ -233,6 +249,7 @@
     return true;
 }
 
+// TODO: Remove.
 bool Ver4PatriciaTrieWritingHelper::truncateBigrams(const int maxBigramCount) {
     const TerminalPositionLookupTable *const terminalPosLookupTable =
             mBuffers->getTerminalPositionLookupTable();