Discard useless unigrams when overflowing.

Bug: 11734037
Change-Id: I5f991dd1f8fa79fd0c442be323d20c76b47ae22e
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.cpp
index b60499e..10f9052 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.cpp
@@ -121,8 +121,14 @@
     const PatriciaTrieReadingUtils::NodeFlags updatedFlags =
             DynamicPatriciaTrieReadingUtils::updateAndGetFlags(originalFlags, false /* isMoved */,
                     false /* isDeleted */, true /* willBecomeNonTerminal */);
-    int writingPos = toBeUpdatedPtNodeParams->getHeadPos();
+    if (!mBuffers->getMutableTerminalPositionLookupTable()->setTerminalPtNodePosition(
+            toBeUpdatedPtNodeParams->getTerminalId(), NOT_A_DICT_POS /* ptNodePos */)) {
+        AKLOGE("Cannot update terminal position lookup table. terminal id: %d",
+                toBeUpdatedPtNodeParams->getTerminalId());
+        return false;
+    }
     // Update flags.
+    int writingPos = toBeUpdatedPtNodeParams->getHeadPos();
     return DynamicPatriciaTrieWritingUtils::writeFlagsAndAdvancePosition(mTrieBuffer, updatedFlags,
             &writingPos);
 }
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.cpp
index 21d009e..77fb41d 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.cpp
@@ -17,6 +17,7 @@
 #include "suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.h"
 
 #include <cstring>
+#include <queue>
 
 #include "suggest/policyimpl/dictionary/bigram/ver4_bigram_list_policy.h"
 #include "suggest/policyimpl/dictionary/header/header_policy.h"
@@ -97,10 +98,16 @@
             &traversePolicyToUpdateUnigramProbabilityAndMarkUselessPtNodesAsDeleted)) {
         return false;
     }
+    const int unigramCount = traversePolicyToUpdateUnigramProbabilityAndMarkUselessPtNodesAsDeleted
+            .getValidUnigramCount();
     if (headerPolicy->isDecayingDict()
-            && traversePolicyToUpdateUnigramProbabilityAndMarkUselessPtNodesAsDeleted
-                    .getValidUnigramCount() > ForgettingCurveUtils::MAX_UNIGRAM_COUNT_AFTER_GC) {
-        // TODO: Remove more unigrams.
+            && unigramCount > ForgettingCurveUtils::MAX_UNIGRAM_COUNT_AFTER_GC) {
+        if (!turncateUnigrams(&ptNodeReader, &ptNodeWriter,
+                ForgettingCurveUtils::MAX_UNIGRAM_COUNT_AFTER_GC)) {
+            AKLOGE("Cannot remove unigrams. current: %d, max: %d", unigramCount,
+                    ForgettingCurveUtils::MAX_UNIGRAM_COUNT_AFTER_GC);
+            return false;
+        }
     }
 
     readingHelper.initWithPtNodeArrayPos(rootPtNodeArrayPos);
@@ -179,6 +186,42 @@
     return true;
 }
 
+bool Ver4PatriciaTrieWritingHelper::turncateUnigrams(
+        const Ver4PatriciaTrieNodeReader *const ptNodeReader,
+        Ver4PatriciaTrieNodeWriter *const ptNodeWriter, const int maxUnigramCount) {
+    // Ranks every terminal that still has a PtNode position, then marks the top-ranked ones as
+    // willBecomeNonTerminal until no more than maxUnigramCount remain. False if a mark fails.
+    std::priority_queue<DictProbability, std::vector<DictProbability>, DictProbabilityComparator>
+            candidates;
+    const TerminalPositionLookupTable *const lookupTable =
+            mBuffers->getTerminalPositionLookupTable();
+    const int terminalIdEnd = lookupTable->getNextTerminalId();
+    for (int terminalId = 0; terminalId < terminalIdEnd; ++terminalId) {
+        const int terminalPos = lookupTable->getTerminalPtNodePosition(terminalId);
+        if (terminalPos != NOT_A_DICT_POS) {
+            const ProbabilityEntry entry =
+                    mBuffers->getProbabilityDictContent()->getProbabilityEntry(terminalId);
+            const int probability = entry.hasHistoricalInfo() ?
+                    ForgettingCurveUtils::decodeProbability(entry.getHistoricalInfo()) :
+                    entry.getProbability();
+            // NOTE(review): timestamp is read even without historical info; confirm that is safe.
+            const int timestamp = entry.getHistoricalInfo()->getTimeStamp();
+            candidates.push(DictProbability(terminalPos, probability, timestamp));
+        }
+    }
+    // Drop from the top of the queue; an entry is popped only after it was marked successfully.
+    for (; static_cast<int>(candidates.size()) > maxUnigramCount; candidates.pop()) {
+        const int ptNodePos = candidates.top().getDictPos();
+        const PtNodeParams ptNodeParams =
+                ptNodeReader->fetchNodeInfoInBufferFromPtNodePos(ptNodePos);
+        if (!ptNodeWriter->markPtNodeAsWillBecomeNonTerminal(&ptNodeParams)) {
+            AKLOGE("Cannot mark PtNode as willBecomeNonterminal. PtNode pos: %d", ptNodePos);
+            return false;
+        }
+    }
+    return true;
+}
+
 bool Ver4PatriciaTrieWritingHelper::TraversePolicyToUpdateAllPtNodeFlagsAndTerminalIds
         ::onVisitingPtNode(const PtNodeParams *const ptNodeParams) {
     if (!ptNodeParams->isTerminal()) {
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.h
index 82877fd..26eb678 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.h
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.h
@@ -25,6 +25,7 @@
 
 class HeaderPolicy;
 class Ver4DictBuffers;
+class Ver4PatriciaTrieNodeReader;
 class Ver4PatriciaTrieNodeWriter;
 
 class Ver4PatriciaTrieWritingHelper {
@@ -64,10 +65,56 @@
         const TerminalPositionLookupTable::TerminalIdMap *const mTerminalIdMap;
     };
 
+    // Sort record used by turncateUnigrams(): one entry per terminal, holding the PtNode
+    // position together with the probability and timestamp it is ranked by. Instances are
+    // plain copyable values; they are stored in a std::vector-backed std::priority_queue.
+    class DictProbability {
+     public:
+        // All three values are captured as given; nothing is validated here.
+        DictProbability(const int dictPos, const int probability, const int timestamp)
+                : mDictPos(dictPos), mProbability(probability), mTimestamp(timestamp) {}
+
+        // Position passed as dictPos at construction.
+        int getDictPos() const { return mDictPos; }
+
+        // Probability passed at construction.
+        int getProbability() const { return mProbability; }
+
+        // Timestamp passed at construction.
+        int getTimestamp() const { return mTimestamp; }
+
+     private:
+        DISALLOW_DEFAULT_CONSTRUCTOR(DictProbability);
+
+        int mDictPos;
+        int mProbability;
+        int mTimestamp;
+    };
+
+    // Heap order for turncateUnigrams(): std::priority_queue::top() is the entry with the lowest
+    // probability, then the smallest (presumably oldest) timestamp, then the smallest position.
+    class DictProbabilityComparator {
+     public:
+        bool operator()(const DictProbability &left, const DictProbability &right) const {
+            if (left.getProbability() != right.getProbability()) {
+                return left.getProbability() > right.getProbability();
+            }
+            if (left.getTimestamp() != right.getTimestamp()) {
+                return left.getTimestamp() > right.getTimestamp();
+            }
+            return left.getDictPos() > right.getDictPos();
+        }
+     private:
+        DISALLOW_ASSIGNMENT_OPERATOR(DictProbabilityComparator);
+    };
+
     bool runGC(const int rootPtNodeArrayPos, const HeaderPolicy *const headerPolicy,
             Ver4DictBuffers *const buffersToWrite, int *const outUnigramCount,
             int *const outBigramCount);
 
+    bool turncateUnigrams(const Ver4PatriciaTrieNodeReader *const ptNodeReader,
+            Ver4PatriciaTrieNodeWriter *const ptNodeWriter, const int maxUnigramCount);
+
     Ver4DictBuffers *const mBuffers;
 };
 } // namespace latinime
diff --git a/tests/src/com/android/inputmethod/latin/BinaryDictionaryDecayingTests.java b/tests/src/com/android/inputmethod/latin/BinaryDictionaryDecayingTests.java
index d77a11f..825b877 100644
--- a/tests/src/com/android/inputmethod/latin/BinaryDictionaryDecayingTests.java
+++ b/tests/src/com/android/inputmethod/latin/BinaryDictionaryDecayingTests.java
@@ -303,6 +303,7 @@
         BinaryDictionary binaryDictionary = new BinaryDictionary(dictFile.getAbsolutePath(),
                 0 /* offset */, dictFile.length(), true /* useFullEditDistance */,
                 Locale.getDefault(), TEST_LOCALE, true /* isUpdatable */);
+        setCurrentTime(binaryDictionary, mCurrentTime);
 
         final int[] codePointSet = CodePointUtils.generateCodePointSet(codePointSetSize, random);
         final ArrayList<String> words = new ArrayList<String>();
@@ -339,7 +340,65 @@
         forcePassingLongTime(binaryDictionary);
         assertEquals(0, Integer.parseInt(binaryDictionary.getPropertyForTests(
                 BinaryDictionary.UNIGRAM_COUNT_QUERY)));
+    }
 
+    public void testOverflowUnigrams() {
+        testOverflowUnigrams(FormatSpec.VERSION4);
+    }
+
+    // Adds words until GC is required, then checks that GC drops "weak" but keeps "strong".
+    private void testOverflowUnigrams(final int formatVersion) {
+        final int unigramCount = 20000;
+        final int eachUnigramTypedCount = 5;
+        final int strongUnigramTypedCount = 20;
+        final int weakUnigramTypedCount = 1;
+        final int codePointSetSize = 50;
+        final long seed = System.currentTimeMillis();
+        final Random random = new Random(seed);
+
+        File dictFile = null;
+        try {
+            dictFile = createEmptyDictionaryAndGetFile("TestBinaryDictionary", formatVersion);
+        } catch (IOException e) {
+            fail("IOException while writing an initial dictionary : " + e);
+        }
+        BinaryDictionary binaryDictionary = new BinaryDictionary(dictFile.getAbsolutePath(),
+                0 /* offset */, dictFile.length(), true /* useFullEditDistance */,
+                Locale.getDefault(), TEST_LOCALE, true /* isUpdatable */);
+        setCurrentTime(binaryDictionary, mCurrentTime);
+        final int[] codePointSet = CodePointUtils.generateCodePointSet(codePointSetSize, random);
+
+        final String strong = "strong";
+        final String weak = "weak";
+        for (int j = 0; j < strongUnigramTypedCount; j++) {
+            addUnigramWord(binaryDictionary, strong, DUMMY_PROBABILITY);
+        }
+        for (int j = 0; j < weakUnigramTypedCount; j++) {
+            addUnigramWord(binaryDictionary, weak, DUMMY_PROBABILITY);
+        }
+        assertTrue(binaryDictionary.isValidWord(strong));
+        assertTrue(binaryDictionary.isValidWord(weak));
+
+        for (int i = 0; i < unigramCount; i++) {
+            final String word = CodePointUtils.generateWord(random, codePointSet);
+            for (int j = 0; j < eachUnigramTypedCount; j++) {
+                addUnigramWord(binaryDictionary, word, DUMMY_PROBABILITY);
+            }
+            if (binaryDictionary.needsToRunGC(true /* mindsBlockByGC */)) {
+                final int unigramCountBeforeGC = Integer.parseInt(
+                        binaryDictionary.getPropertyForTests(BinaryDictionary.UNIGRAM_COUNT_QUERY));
+                assertTrue(binaryDictionary.isValidWord(strong));
+                assertTrue(binaryDictionary.isValidWord(weak));
+                binaryDictionary.flushWithGC();
+                final int unigramCountAfterGC = Integer.parseInt(
+                        binaryDictionary.getPropertyForTests(BinaryDictionary.UNIGRAM_COUNT_QUERY));
+                assertTrue(unigramCountBeforeGC > unigramCountAfterGC);
+                assertFalse(binaryDictionary.isValidWord(weak));
+                assertTrue(binaryDictionary.isValidWord(strong));
+                return; // The GC pass was exercised and verified.
+            }
+        }
+        fail("GC was never required after adding " + unigramCount + " words. seed: " + seed);
     }
 
     public void testAddManyBigramsToDecayingDict() {
@@ -363,6 +422,7 @@
         BinaryDictionary binaryDictionary = new BinaryDictionary(dictFile.getAbsolutePath(),
                 0 /* offset */, dictFile.length(), true /* useFullEditDistance */,
                 Locale.getDefault(), TEST_LOCALE, true /* isUpdatable */);
+        setCurrentTime(binaryDictionary, mCurrentTime);
 
         final int[] codePointSet = CodePointUtils.generateCodePointSet(codePointSetSize, random);
         final ArrayList<String> words = new ArrayList<String>();