Discard useless bigrams when overflowing.

Bug: 11734037
Change-Id: I339999dffc1902f62e0a6706d969c8f410bb656f
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.cpp
index 77fb41d..2ee9483 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.cpp
@@ -102,7 +102,7 @@
             .getValidUnigramCount();
     if (headerPolicy->isDecayingDict()
             && unigramCount > ForgettingCurveUtils::MAX_UNIGRAM_COUNT_AFTER_GC) {
-        if (!turncateUnigrams(&ptNodeReader, &ptNodeWriter,
+        if (!truncateUnigrams(&ptNodeReader, &ptNodeWriter,
                 ForgettingCurveUtils::MAX_UNIGRAM_COUNT_AFTER_GC)) {
             AKLOGE("Cannot remove unigrams. current: %d, max: %d", unigramCount,
                     ForgettingCurveUtils::MAX_UNIGRAM_COUNT_AFTER_GC);
@@ -117,10 +117,14 @@
             &traversePolicyToUpdateBigramProbability)) {
         return false;
     }
+    const int bigramCount = traversePolicyToUpdateBigramProbability.getValidBigramEntryCount();
     if (headerPolicy->isDecayingDict()
-            && traversePolicyToUpdateBigramProbability.getValidBigramEntryCount()
-                    > ForgettingCurveUtils::MAX_BIGRAM_COUNT_AFTER_GC) {
-        // TODO: Remove more bigrams.
+            && bigramCount > ForgettingCurveUtils::MAX_BIGRAM_COUNT_AFTER_GC) {
+        if (!truncateBigrams(ForgettingCurveUtils::MAX_BIGRAM_COUNT_AFTER_GC)) {
+            AKLOGE("Cannot remove bigrams. current: %d, max: %d", bigramCount,
+                    ForgettingCurveUtils::MAX_BIGRAM_COUNT_AFTER_GC);
+            return false;
+        }
     }
 
     // Mapping from positions in mBuffer to positions in bufferToWrite.
@@ -186,7 +190,7 @@
     return true;
 }
 
-bool Ver4PatriciaTrieWritingHelper::turncateUnigrams(
+bool Ver4PatriciaTrieWritingHelper::truncateUnigrams(
         const Ver4PatriciaTrieNodeReader *const ptNodeReader,
         Ver4PatriciaTrieNodeWriter *const ptNodeWriter, const int maxUnigramCount) {
     const TerminalPositionLookupTable *const terminalPosLookupTable =
@@ -222,6 +226,50 @@
     return true;
 }
 
+bool Ver4PatriciaTrieWritingHelper::truncateBigrams(const int maxBigramCount) {
+    const TerminalPositionLookupTable *const terminalPosLookupTable =
+            mBuffers->getTerminalPositionLookupTable();
+    const int nextTerminalId = terminalPosLookupTable->getNextTerminalId();
+    std::priority_queue<DictProbability, std::vector<DictProbability>, DictProbabilityComparator>
+            priorityQueue;
+    BigramDictContent *const bigramDictContent = mBuffers->getMutableBigramDictContent();
+    for (int i = 0; i < nextTerminalId; ++i) {
+        const int bigramListPos = bigramDictContent->getBigramListHeadPos(i);
+        if (bigramListPos == NOT_A_DICT_POS) {
+            continue;
+        }
+        bool hasNext = true;
+        int readingPos = bigramListPos;
+        while (hasNext) {
+            const int entryPos = readingPos;
+            const BigramEntry bigramEntry =
+                    bigramDictContent->getBigramEntryAndAdvancePosition(&readingPos);
+            hasNext = bigramEntry.hasNext();
+            if (!bigramEntry.isValid()) {
+                continue;
+            }
+            const int probability = bigramEntry.hasHistoricalInfo() ?
+                    ForgettingCurveUtils::decodeProbability(bigramEntry.getHistoricalInfo()) :
+                            bigramEntry.getProbability();
+            priorityQueue.push(DictProbability(entryPos, probability,
+                    bigramEntry.getHistoricalInfo()->getTimeStamp()));
+        }
+    }
+
+    // Delete bigrams.
+    while (static_cast<int>(priorityQueue.size()) > maxBigramCount) {
+        const int entryPos = priorityQueue.top().getDictPos();
+        const BigramEntry bigramEntry = bigramDictContent->getBigramEntry(entryPos);
+        const BigramEntry invalidatedBigramEntry = bigramEntry.getInvalidatedEntry();
+        if (!bigramDictContent->writeBigramEntry(&invalidatedBigramEntry, entryPos)) {
+            AKLOGE("Cannot write bigram entry to remove. pos: %d", entryPos);
+            return false;
+        }
+        priorityQueue.pop();
+    }
+    return true;
+}
+
 bool Ver4PatriciaTrieWritingHelper::TraversePolicyToUpdateAllPtNodeFlagsAndTerminalIds
         ::onVisitingPtNode(const PtNodeParams *const ptNodeParams) {
     if (!ptNodeParams->isTerminal()) {
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.h
index 26eb678..73e63ef 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.h
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_writing_helper.h
@@ -65,7 +65,7 @@
         const TerminalPositionLookupTable::TerminalIdMap *const mTerminalIdMap;
     };
 
-    // For truncateUnigrams().
+    // For truncateUnigrams() and truncateBigrams().
     class DictProbability {
      public:
         DictProbability(const int dictPos, const int probability, const int timestamp)
@@ -91,7 +91,7 @@
         int mTimestamp;
     };
 
-    // For truncateUnigrams().
+    // For truncateUnigrams() and truncateBigrams().
     class DictProbabilityComparator {
      public:
         bool operator()(const DictProbability &left, const DictProbability &right) {
@@ -112,9 +112,11 @@
             Ver4DictBuffers *const buffersToWrite, int *const outUnigramCount,
             int *const outBigramCount);
 
-    bool turncateUnigrams(const Ver4PatriciaTrieNodeReader *const ptNodeReader,
+    bool truncateUnigrams(const Ver4PatriciaTrieNodeReader *const ptNodeReader,
             Ver4PatriciaTrieNodeWriter *const ptNodeWriter, const int maxUnigramCount);
 
+    bool truncateBigrams(const int maxBigramCount);
+
     Ver4DictBuffers *const mBuffers;
 };
 } // namespace latinime
diff --git a/tests/src/com/android/inputmethod/latin/BinaryDictionaryDecayingTests.java b/tests/src/com/android/inputmethod/latin/BinaryDictionaryDecayingTests.java
index 825b877..c63193a 100644
--- a/tests/src/com/android/inputmethod/latin/BinaryDictionaryDecayingTests.java
+++ b/tests/src/com/android/inputmethod/latin/BinaryDictionaryDecayingTests.java
@@ -474,4 +474,83 @@
         assertEquals(0, Integer.parseInt(binaryDictionary.getPropertyForTests(
                 BinaryDictionary.BIGRAM_COUNT_QUERY)));
     }
+
+    public void testOverflowBigrams() {
+        testOverflowBigrams(FormatSpec.VERSION4);
+    }
+
+    private void testOverflowBigrams(final int formatVersion) {
+        final int bigramCount = 20000;
+        final int unigramCount = 1000;
+        final int unigramTypedCount = 20;
+        final int eachBigramTypedCount = 5;
+        final int strongBigramTypedCount = 20;
+        final int weakBigramTypedCount = 1;
+        final int codePointSetSize = 50;
+        final long seed = System.currentTimeMillis();
+        final Random random = new Random(seed);
+
+        File dictFile = null;
+        try {
+            dictFile = createEmptyDictionaryAndGetFile("TestBinaryDictionary", formatVersion);
+        } catch (IOException e) {
+            fail("IOException while writing an initial dictionary : " + e);
+        }
+        BinaryDictionary binaryDictionary = new BinaryDictionary(dictFile.getAbsolutePath(),
+                0 /* offset */, dictFile.length(), true /* useFullEditDistance */,
+                Locale.getDefault(), TEST_LOCALE, true /* isUpdatable */);
+        setCurrentTime(binaryDictionary, mCurrentTime);
+        final int[] codePointSet = CodePointUtils.generateCodePointSet(codePointSetSize, random);
+
+        final ArrayList<String> words = new ArrayList<String>();
+        for (int i = 0; i < unigramCount; i++) {
+            final String word = CodePointUtils.generateWord(random, codePointSet);
+            words.add(word);
+            for (int j = 0; j < unigramTypedCount; j++) {
+                addUnigramWord(binaryDictionary, word, DUMMY_PROBABILITY);
+            }
+        }
+        final String strong = "strong";
+        final String weak = "weak";
+        final String target = "target";
+        for (int j = 0; j < unigramTypedCount; j++) {
+            addUnigramWord(binaryDictionary, strong, DUMMY_PROBABILITY);
+            addUnigramWord(binaryDictionary, weak, DUMMY_PROBABILITY);
+            addUnigramWord(binaryDictionary, target, DUMMY_PROBABILITY);
+        }
+        binaryDictionary.flushWithGC();
+        for (int j = 0; j < strongBigramTypedCount; j++) {
+            addBigramWords(binaryDictionary, strong, target, DUMMY_PROBABILITY);
+        }
+        for (int j = 0; j < weakBigramTypedCount; j++) {
+            addBigramWords(binaryDictionary, weak, target, DUMMY_PROBABILITY);
+        }
+        assertTrue(binaryDictionary.isValidBigram(strong, target));
+        assertTrue(binaryDictionary.isValidBigram(weak, target));
+
+        for (int i = 0; i < bigramCount; i++) {
+            final int word0Index = random.nextInt(words.size());
+            final String word0 = words.get(word0Index);
+            final int index = random.nextInt(words.size() - 1);
+            final int word1Index = (index >= word0Index) ? index + 1 : index;
+            final String word1 = words.get(word1Index);
+
+            for (int j = 0; j < eachBigramTypedCount; j++) {
+                addBigramWords(binaryDictionary, word0, word1, DUMMY_PROBABILITY);
+            }
+            if (binaryDictionary.needsToRunGC(true /* mindsBlockByGC */)) {
+                final int bigramCountBeforeGC =
+                        Integer.parseInt(binaryDictionary.getPropertyForTests(
+                                BinaryDictionary.BIGRAM_COUNT_QUERY));
+                binaryDictionary.flushWithGC();
+                final int bigramCountAfterGC =
+                        Integer.parseInt(binaryDictionary.getPropertyForTests(
+                                BinaryDictionary.BIGRAM_COUNT_QUERY));
+                assertTrue(bigramCountBeforeGC > bigramCountAfterGC);
+                assertTrue(binaryDictionary.isValidBigram(strong, target));
+                assertFalse(binaryDictionary.isValidBigram(weak, target));
+                break;
+            }
+        }
+    }
 }