Add bigram lookup implementation.

Bug: 5046459
Change-Id: Id2c7686c5da078751ed587e559417e808779aa7a
diff --git a/native/src/bigram_dictionary.cpp b/native/src/bigram_dictionary.cpp
index 6ed4d09..c340c6c 100644
--- a/native/src/bigram_dictionary.cpp
+++ b/native/src/bigram_dictionary.cpp
@@ -21,13 +21,14 @@
 
 #include "bigram_dictionary.h"
 #include "dictionary.h"
+#include "binary_format.h"
 
 namespace latinime {
 
 BigramDictionary::BigramDictionary(const unsigned char *dict, int maxWordLength,
         int maxAlternatives, const bool isLatestDictVersion, const bool hasBigram,
         Dictionary *parentDictionary)
-    : DICT(dict), MAX_WORD_LENGTH(maxWordLength),
+    : DICT(dict + NEW_DICTIONARY_HEADER_SIZE), MAX_WORD_LENGTH(maxWordLength),
     MAX_ALTERNATIVES(maxAlternatives), IS_LATEST_DICT_VERSION(isLatestDictVersion),
     HAS_BIGRAM(hasBigram), mParentDictionary(parentDictionary) {
     if (DEBUG_DICT) {
@@ -82,169 +83,66 @@
     return false;
 }
 
-int BigramDictionary::getBigramAddress(int *pos, bool advance) {
-    int address = 0;
-
-    address += (DICT[*pos] & 0x3F) << 16;
-    address += (DICT[*pos + 1] & 0xFF) << 8;
-    address += (DICT[*pos + 2] & 0xFF);
-
-    if (advance) {
-        *pos += 3;
-    }
-
-    return address;
-}
-
-int BigramDictionary::getBigramFreq(int *pos) {
-    int freq = DICT[(*pos)++] & FLAG_BIGRAM_FREQ;
-
-    return freq;
-}
-
-
+/* Parameters :
+ * prevWord: the word before, the one for which we need to look up bigrams.
+ * prevWordLength: its length.
+ * codes: what user typed, in the same format as for UnigramDictionary::getSuggestions.
+ * codesSize: the size of the codes array.
+ * bigramChars: an array for output, in the same format as outwords for getSuggestions.
+ * bigramFreq: an array to output frequencies.
+ * maxWordLength: the maximum size of a word.
+ * maxBigrams: the maximum number of bigrams fitting in the bigramChars array.
+ * maxAlternatives: unused.
+ * This method returns the number of bigrams this word has, for backward compatibility.
+ * Note: this is not the number of bigrams output in the array, which is the number of
+ * bigrams this word has WHOSE first letter also matches the letter the user typed.
+ * TODO: this may not be a sensible thing to do. It makes sense when the bigrams are
+ * used to match the first letter of the second word, but once the user has typed more
+ * and the bigrams are used to boost unigram result scores, it makes little sense to
+ * reduce their scope to the ones that match the first letter.
+ */
 int BigramDictionary::getBigrams(unsigned short *prevWord, int prevWordLength, int *codes,
         int codesSize, unsigned short *bigramChars, int *bigramFreq, int maxWordLength,
         int maxBigrams, int maxAlternatives) {
+    // TODO: remove unused arguments, and refrain from storing stuff in members of this class
+    // TODO: have "in" arguments before "out" ones, and make out args explicit in the name
     mBigramFreq = bigramFreq;
     mBigramChars = bigramChars;
     mInputCodes = codes;
-    mInputLength = codesSize;
     mMaxBigrams = maxBigrams;
 
-    if (HAS_BIGRAM && IS_LATEST_DICT_VERSION) {
-        int pos = mParentDictionary->getBigramPosition(prevWord, prevWordLength);
-        if (DEBUG_DICT) {
-            LOGI("Pos -> %d", pos);
-        }
-        if (pos < 0) {
-            return 0;
-        }
+    const uint8_t* const root = DICT;
+    int pos = BinaryFormat::getTerminalPosition(root, prevWord, prevWordLength);
 
-        int bigramCount = 0;
-        int bigramExist = (DICT[pos] & FLAG_BIGRAM_READ);
-        if (bigramExist > 0) {
-            int nextBigramExist = 1;
-            while (nextBigramExist > 0 && bigramCount < maxBigrams) {
-                int bigramAddress = getBigramAddress(&pos, true);
-                int frequency = (FLAG_BIGRAM_FREQ & DICT[pos]);
-                // search for all bigrams and store them
-                searchForTerminalNode(bigramAddress, frequency);
-                nextBigramExist = (DICT[pos++] & FLAG_BIGRAM_CONTINUED);
-                bigramCount++;
-            }
-        }
-
-        return bigramCount;
+    if (NOT_VALID_WORD == pos) return 0;
+    const int flags = BinaryFormat::getFlagsAndForwardPointer(root, &pos);
+    if (0 == (flags & UnigramDictionary::FLAG_HAS_BIGRAMS)) return 0;
+    if (0 == (flags & UnigramDictionary::FLAG_HAS_MULTIPLE_CHARS)) {
+        BinaryFormat::getCharCodeAndForwardPointer(root, &pos);
+    } else {
+        pos = BinaryFormat::skipOtherCharacters(root, pos);
     }
-    return 0;
-}
+    pos = BinaryFormat::skipChildrenPosition(flags, pos);
+    pos = BinaryFormat::skipFrequency(flags, pos);
+    int bigramFlags;
+    int bigramCount = 0;
+    do {
+        bigramFlags = BinaryFormat::getFlagsAndForwardPointer(root, &pos);
+        uint16_t bigramBuffer[MAX_WORD_LENGTH];
+        const int bigramPos = BinaryFormat::getAttributeAddressAndForwardPointer(root, bigramFlags,
+                &pos);
+        const int length = BinaryFormat::getWordAtAddress(root, bigramPos, MAX_WORD_LENGTH,
+                bigramBuffer);
 
-void BigramDictionary::searchForTerminalNode(int addressLookingFor, int frequency) {
-    // track word with such address and store it in an array
-    unsigned short word[MAX_WORD_LENGTH];
-
-    int pos;
-    int followDownBranchAddress = DICTIONARY_HEADER_SIZE;
-    bool found = false;
-    char followingChar = ' ';
-    int depth = -1;
-
-    while(!found) {
-        bool followDownAddressSearchStop = false;
-        bool firstAddress = true;
-        bool haveToSearchAll = true;
-
-        if (depth < MAX_WORD_LENGTH && depth >= 0) {
-            word[depth] = (unsigned short) followingChar;
+        // getWordAtAddress returns 0 when it cannot resolve bigramPos; bigramBuffer then does
+        // not hold a valid word, so it must not be inspected (it may be uninitialized).
+        if (length > 0 && checkFirstCharacter(bigramBuffer)) {
+            const int frequency = UnigramDictionary::MASK_ATTRIBUTE_FREQUENCY & bigramFlags;
+            addWordBigram(bigramBuffer, length, frequency);
         }
-        pos = followDownBranchAddress; // pos start at count
-        int count = DICT[pos] & 0xFF;
-        if (DEBUG_DICT) {
-            LOGI("count - %d",count);
-        }
-        pos++;
-        for (int i = 0; i < count; i++) {
-            // pos at data
-            pos++;
-            // pos now at flag
-            if (!getFirstBitOfByte(&pos)) { // non-terminal
-                if (!followDownAddressSearchStop) {
-                    int addr = getBigramAddress(&pos, false);
-                    if (addr > addressLookingFor) {
-                        followDownAddressSearchStop = true;
-                        if (firstAddress) {
-                            firstAddress = false;
-                            haveToSearchAll = true;
-                        } else if (!haveToSearchAll) {
-                            break;
-                        }
-                    } else {
-                        followDownBranchAddress = addr;
-                        followingChar = (char)(0xFF & DICT[pos-1]);
-                        if (firstAddress) {
-                            firstAddress = false;
-                            haveToSearchAll = false;
-                        }
-                    }
-                }
-                pos += 3;
-            } else if (getFirstBitOfByte(&pos)) { // terminal
-                if (addressLookingFor == (pos-1)) { // found !!
-                    depth++;
-                    word[depth] = (0xFF & DICT[pos-1]);
-                    found = true;
-                    break;
-                }
-                if (getSecondBitOfByte(&pos)) { // address + freq (4 byte)
-                    if (!followDownAddressSearchStop) {
-                        int addr = getBigramAddress(&pos, false);
-                        if (addr > addressLookingFor) {
-                            followDownAddressSearchStop = true;
-                            if (firstAddress) {
-                                firstAddress = false;
-                                haveToSearchAll = true;
-                            } else if (!haveToSearchAll) {
-                                break;
-                            }
-                        } else {
-                            followDownBranchAddress = addr;
-                            followingChar = (char)(0xFF & DICT[pos-1]);
-                            if (firstAddress) {
-                                firstAddress = false;
-                                haveToSearchAll = true;
-                            }
-                        }
-                    }
-                    pos += 4;
-                } else { // freq only (2 byte)
-                    pos += 2;
-                }
-
-                // skipping bigram
-                int bigramExist = (DICT[pos] & FLAG_BIGRAM_READ);
-                if (bigramExist > 0) {
-                    int nextBigramExist = 1;
-                    while (nextBigramExist > 0) {
-                        pos += 3;
-                        nextBigramExist = (DICT[pos++] & FLAG_BIGRAM_CONTINUED);
-                    }
-                } else {
-                    pos++;
-                }
-            }
-        }
-        depth++;
-        if (followDownBranchAddress == 0) {
-            if (DEBUG_DICT) {
-                LOGI("ERROR!!! Cannot find bigram!!");
-            }
-            break;
-        }
-    }
-    if (checkFirstCharacter(word)) {
-        addWordBigram(word, depth, frequency);
-    }
+        ++bigramCount;
+    } while (0 != (UnigramDictionary::FLAG_ATTRIBUTE_HAS_NEXT & bigramFlags));
+    return bigramCount;
 }
 
 bool BigramDictionary::checkFirstCharacter(unsigned short *word) {
diff --git a/native/src/binary_format.h b/native/src/binary_format.h
index a946b1e..6f65088 100644
--- a/native/src/binary_format.h
+++ b/native/src/binary_format.h
@@ -50,6 +50,8 @@
             int *pos);
     static int getTerminalPosition(const uint8_t* const root, const uint16_t* const inWord,
             const int length);
+    static int getWordAtAddress(const uint8_t* const root, const int address, const int maxDepth,
+            uint16_t* outWord);
 };
 
 inline int BinaryFormat::detectFormat(const uint8_t* const dict) {
@@ -290,6 +292,151 @@
     }
 }
 
+// This function searches for a terminal in the dictionary by its address.
+// Due to the fact that words are ordered in the dictionary in a strict breadth-first order,
+// it is possible to check for this with advantageous complexity. For each node, we search
+// for groups with children and compare the children address with the address we look for.
+// When we shoot the address we look for, it means the word we look for is in the children
+// of the previous group. The only tricky part is the fact that if we arrive at the end of a
+// node with the last group's children address still less than what we are searching for, we
+// must descend the last group's children (for example, if the word we are searching for starts
+// with a z, it's the last group of the root node, so all children addresses will be smaller
+// than the address we look for, and we have to descend the z node).
+/* Parameters :
+ * root: the dictionary buffer
+ * address: the byte position of the last chargroup of the word we are searching for (this is
+ *   what is stored as the "bigram address" in each bigram)
+ * outword: an array to write the found word, with MAX_WORD_LENGTH size.
+ * Return value : the length of the word, of 0 if the word was not found.
+ */
+inline int BinaryFormat::getWordAtAddress(const uint8_t* const root, const int address,
+        const int maxDepth, uint16_t* outWord) {
+    int pos = 0;
+    int wordPos = 0;
+
+    // One iteration of the outer loop iterates through nodes. As stated above, we will only
+    // traverse nodes that are actually a part of the terminal we are searching, so each time
+    // we enter this loop we are one depth level further than last time.
+    // The only reason we count nodes is because we want to reduce the probability of infinite
+    // looping in case there is a bug. Since we know there is an upper bound to the depth we are
+    // supposed to traverse, it does not hurt to count iterations.
+    for (int loopCount = maxDepth; loopCount > 0; --loopCount) {
+        int lastCandidateGroupPos = 0;
+        // Let's loop through char groups in this node searching for either the terminal
+        // or one of its ascendants.
+        for (int charGroupCount = getGroupCountAndForwardPointer(root, &pos); charGroupCount > 0;
+                 --charGroupCount) {
+            const int startPos = pos;
+            const uint8_t flags = getFlagsAndForwardPointer(root, &pos);
+            const int32_t character = getCharCodeAndForwardPointer(root, &pos);
+            if (address == startPos) {
+                // We found the address. Copy the rest of the word in the buffer and return
+                // the length.
+                outWord[wordPos] = character;
+                if (UnigramDictionary::FLAG_HAS_MULTIPLE_CHARS & flags) {
+                    int32_t nextChar = getCharCodeAndForwardPointer(root, &pos);
+                    // We count chars in order to avoid infinite loops if the file is broken or
+                    // if there is some other bug
+                    int charCount = maxDepth;
+                    while (-1 != nextChar && --charCount > 0) {
+                        outWord[++wordPos] = nextChar;
+                        nextChar = getCharCodeAndForwardPointer(root, &pos);
+                    }
+                }
+                return ++wordPos;
+            }
+            // We need to skip past this char group, so skip any remaining chars after the
+            // first and possibly the frequency.
+            if (UnigramDictionary::FLAG_HAS_MULTIPLE_CHARS & flags) {
+                pos = skipOtherCharacters(root, pos);
+            }
+            pos = skipFrequency(flags, pos);
+
+            // The fact that this group has children is very important. Since we already know
+            // that this group does not match, if it has no children we know it is irrelevant
+            // to what we are searching for.
+            const bool hasChildren = (UnigramDictionary::FLAG_GROUP_ADDRESS_TYPE_NOADDRESS !=
+                    (UnigramDictionary::MASK_GROUP_ADDRESS_TYPE & flags));
+            // We will write in `found' whether we have passed the children address we are
+            // searching for. For example if we search for "beer", the children of b are less
+            // than the address we are searching for and the children of c are greater. When we
+            // come here for c, we realize this is too big, and that we should descend b.
+            bool found;
+            if (hasChildren) {
+                // Here comes the tricky part. First, read the children position.
+                const int childrenPos = readChildrenPosition(root, flags, pos);
+                if (childrenPos > address) {
+                    // If the children pos is greater than address, it means the previous chargroup,
+                    // which address is stored in lastCandidateGroupPos, was the right one.
+                    found = true;
+                } else if (1 >= charGroupCount) {
+                    // However if we are on the LAST group of this node, and we have NOT shot the
+                    // address we should descend THIS node. So we trick the lastCandidateGroupPos
+                    // so that we will descend this node, not the previous one.
+                    lastCandidateGroupPos = startPos;
+                    found = true;
+                } else {
+                    // Else, we should continue looking.
+                    found = false;
+                }
+            } else {
+                // Even if we don't have children here, we could still be on the last group of this
+                // node. If this is the case, we should descend the last group that had children,
+                // and their address is already in lastCandidateGroup.
+                found = (1 >= charGroupCount);
+            }
+
+            if (found) {
+                // Okay, we found the group we should descend. Its address is in
+                // the lastCandidateGroupPos variable, so we just re-read it.
+                if (0 != lastCandidateGroupPos) {
+                    const uint8_t lastFlags =
+                            getFlagsAndForwardPointer(root, &lastCandidateGroupPos);
+                    const int32_t lastChar =
+                            getCharCodeAndForwardPointer(root, &lastCandidateGroupPos);
+                    // We copy all the characters in this group to the buffer
+                    outWord[wordPos] = lastChar;
+                    if (UnigramDictionary::FLAG_HAS_MULTIPLE_CHARS & lastFlags) {
+                        int32_t nextChar =
+                                getCharCodeAndForwardPointer(root, &lastCandidateGroupPos);
+                        int charCount = maxDepth;
+                        while (-1 != nextChar && --charCount > 0) {
+                            outWord[++wordPos] = nextChar;
+                            nextChar = getCharCodeAndForwardPointer(root, &lastCandidateGroupPos);
+                        }
+                    }
+                    ++wordPos;
+                    // Now we only need to branch to the children address. Skip the frequency if
+                    // it's there, read pos, and break to resume the search at pos.
+                    lastCandidateGroupPos = skipFrequency(lastFlags, lastCandidateGroupPos);
+                    pos = readChildrenPosition(root, lastFlags, lastCandidateGroupPos);
+                    break;
+                } else {
+                    // Here is a little tricky part: we come here if we found out that all children
+                    // addresses in this group are bigger than the address we are searching for.
+                    // Should we conclude the word is not in the dictionary? No! It could still be
+                    // one of the remaining chargroups in this node, so we have to keep looking in
+                    // this node until we find it (or we realize it's not there either, in which
+                    // case it's actually not in the dictionary). Pass the end of this group, ready
+                    // to start the next one.
+                    pos = skipChildrenPosAndAttributes(root, flags, pos);
+                }
+            } else {
+                // If we did not find it, we should record the last children address for the next
+                // iteration.
+                if (hasChildren) lastCandidateGroupPos = startPos;
+                // Now skip the end of this group (children pos and the attributes if any) so that
+                // our pos is after the end of this char group, at the start of the next one.
+                pos = skipChildrenPosAndAttributes(root, flags, pos);
+            }
+
+        }
+    }
+    // If we have looked through all the chargroups and found no match, the address is
+    // not the address of a terminal in this dictionary.
+    return 0;
+}
+
 } // namespace latinime
 
 #endif // LATINIME_BINARY_FORMAT_H
diff --git a/native/src/dictionary.cpp b/native/src/dictionary.cpp
index 9e32ee8..a49769b 100644
--- a/native/src/dictionary.cpp
+++ b/native/src/dictionary.cpp
@@ -57,12 +57,4 @@
     return mUnigramDictionary->isValidWord(word, length);
 }
 
-int Dictionary::getBigramPosition(unsigned short *word, int length) {
-    if (IS_LATEST_DICT_VERSION) {
-        return mUnigramDictionary->getBigramPosition(DICTIONARY_HEADER_SIZE, word, 0, length);
-    } else {
-        return mUnigramDictionary->getBigramPosition(0, word, 0, length);
-    }
-}
-
 } // namespace latinime
diff --git a/native/src/dictionary.h b/native/src/dictionary.h
index 73e03d8..d5de008 100644
--- a/native/src/dictionary.h
+++ b/native/src/dictionary.h
@@ -64,8 +64,6 @@
             const int pos, unsigned short *c, int *childrenPosition,
             bool *terminal, int *freq);
     static inline unsigned short toBaseLowerCase(unsigned short c);
-    // TODO: delete this
-    int getBigramPosition(unsigned short *word, int length);
 
 private:
     bool hasBigram();