Support Beginning-of-Sentence in native code
Bug: 14119293
Change-Id: I0f382e33a19bf481823b23405d454de61ec835ff
diff --git a/native/jni/src/suggest/core/session/prev_words_info.h b/native/jni/src/suggest/core/session/prev_words_info.h
index e4de1f4..a58000a 100644
--- a/native/jni/src/suggest/core/session/prev_words_info.h
+++ b/native/jni/src/suggest/core/session/prev_words_info.h
@@ -20,11 +20,11 @@
#include "defines.h"
#include "suggest/core/dictionary/binary_dictionary_bigrams_iterator.h"
#include "suggest/core/policy/dictionary_structure_with_buffer_policy.h"
+#include "utils/char_utils.h"
namespace latinime {
// TODO: Support n-gram.
-// TODO: Support beginning of sentence.
// This class does not take ownership of any code point buffers.
class PrevWordsInfo {
public:
@@ -52,8 +52,7 @@
void getPrevWordsTerminalPtNodePos(
const DictionaryStructureWithBufferPolicy *const dictStructurePolicy,
- int *const outPrevWordsTerminalPtNodePos,
- const bool tryLowerCaseSearch) const {
+ int *const outPrevWordsTerminalPtNodePos, const bool tryLowerCaseSearch) const {
for (size_t i = 0; i < NELEMS(mPrevWordCodePoints); ++i) {
outPrevWordsTerminalPtNodePos[i] = getTerminalPtNodePosOfWord(dictStructurePolicy,
mPrevWordCodePoints[i], mPrevWordCodePointCount[i],
@@ -63,17 +62,11 @@
+ // Creates the bigram iterator used for prediction from the first previous word (index 0).
+ // The bigram list position comes from getBigramListPositionForWordWithTryingLowerCaseSearch,
+ // which attaches the Beginning-of-Sentence marker when mIsBeginningOfSentence[0] is set and
+ // retries the lookup in lower case when the first lookup returns NOT_A_DICT_POS.
BinaryDictionaryBigramsIterator getBigramsIteratorForPrediction(
const DictionaryStructureWithBufferPolicy *const dictStructurePolicy) const {
- int pos = getBigramListPositionForWord(dictStructurePolicy, mPrevWordCodePoints[0],
- mPrevWordCodePointCount[0], false /* forceLowerCaseSearch */);
- // getBigramListPositionForWord returns NOT_A_DICT_POS if this word isn't in the
- // dictionary or has no bigrams
- if (NOT_A_DICT_POS == pos) {
- // If no bigrams for this exact word, search again in lower case.
- pos = getBigramListPositionForWord(dictStructurePolicy, mPrevWordCodePoints[0],
- mPrevWordCodePointCount[0], true /* forceLowerCaseSearch */);
- }
- return BinaryDictionaryBigramsIterator(
- dictStructurePolicy->getBigramsStructurePolicy(), pos);
+ const int bigramListPos = getBigramListPositionForWordWithTryingLowerCaseSearch(
+ dictStructurePolicy, mPrevWordCodePoints[0], mPrevWordCodePointCount[0],
+ mIsBeginningOfSentence[0]);
+ // The position is handed to the iterator unchanged, even when it is NOT_A_DICT_POS.
+ return BinaryDictionaryBigramsIterator(dictStructurePolicy->getBigramsStructurePolicy(),
+ bigramListPos);
}
// n is 1-indexed.
@@ -102,8 +95,18 @@
if (!dictStructurePolicy || !wordCodePoints) {
return NOT_A_DICT_POS;
}
+ int codePoints[MAX_WORD_LENGTH];
+ int codePointCount = wordCodePointCount;
+ memmove(codePoints, wordCodePoints, sizeof(int) * codePointCount);
+ if (isBeginningOfSentence) {
+ codePointCount = CharUtils::attachBeginningOfSentenceMarker(codePoints,
+ codePointCount, MAX_WORD_LENGTH);
+ if (codePointCount <= 0) {
+ return NOT_A_DICT_POS;
+ }
+ }
const int wordPtNodePos = dictStructurePolicy->getTerminalPtNodePositionOfWord(
- wordCodePoints, wordCodePointCount, false /* forceLowerCaseSearch */);
+ codePoints, codePointCount, false /* forceLowerCaseSearch */);
if (wordPtNodePos != NOT_A_DICT_POS || !tryLowerCaseSearch) {
// Return the position when when the word was found or doesn't try lower case
// search.
@@ -112,7 +115,33 @@
// Check bigrams for lower-cased previous word if original was not found. Useful for
// auto-capitalized words like "The [current_word]".
return dictStructurePolicy->getTerminalPtNodePositionOfWord(
- wordCodePoints, wordCodePointCount, true /* forceLowerCaseSearch */);
+ codePoints, codePointCount, true /* forceLowerCaseSearch */);
+ }
+
+ // Looks up the bigram list position of a word, first with the code points as given and
+ // then, if that lookup returns NOT_A_DICT_POS, again with forced lower-case search.
+ //
+ // dictStructurePolicy: dictionary to search.
+ // wordCodePoints, wordCodePointCount: code points of the word. They are copied into a
+ //         local buffer of MAX_WORD_LENGTH entries and are never modified.
+ // isBeginningOfSentence: when true, the Beginning-of-Sentence marker is attached to the
+ //         local copy before the lookups.
+ //
+ // Returns NOT_A_DICT_POS when the arguments are unusable (null dictionary, null word with a
+ // positive count, count that is negative or larger than MAX_WORD_LENGTH) or when the marker
+ // cannot be attached; otherwise returns the result of the last lookup performed.
+ static int getBigramListPositionForWordWithTryingLowerCaseSearch(
+         const DictionaryStructureWithBufferPolicy *const dictStructurePolicy,
+         const int *const wordCodePoints, const int wordCodePointCount,
+         const bool isBeginningOfSentence) {
+     if (!dictStructurePolicy) {
+         return NOT_A_DICT_POS;
+     }
+     if (wordCodePointCount < 0 || wordCodePointCount > MAX_WORD_LENGTH) {
+         // The word would not fit in the local buffer below.
+         return NOT_A_DICT_POS;
+     }
+     if (!wordCodePoints && wordCodePointCount > 0) {
+         // There is nothing to copy the claimed code points from.
+         return NOT_A_DICT_POS;
+     }
+     int codePoints[MAX_WORD_LENGTH];
+     int codePointCount = wordCodePointCount;
+     if (codePointCount > 0) {
+         // Skipped for an empty word so that memmove is never handed a null source. An
+         // empty word can still be looked up as a bare Beginning-of-Sentence marker.
+         memmove(codePoints, wordCodePoints, sizeof(int) * codePointCount);
+     }
+     if (isBeginningOfSentence) {
+         codePointCount = CharUtils::attachBeginningOfSentenceMarker(codePoints,
+                 codePointCount, MAX_WORD_LENGTH);
+         if (codePointCount <= 0) {
+             // No room was left in the buffer for the marker.
+             return NOT_A_DICT_POS;
+         }
+     }
+     int pos = getBigramListPositionForWord(dictStructurePolicy, codePoints,
+             codePointCount, false /* forceLowerCaseSearch */);
+     // getBigramListPositionForWord returns NOT_A_DICT_POS if this word isn't in the
+     // dictionary or has no bigrams
+     if (NOT_A_DICT_POS == pos) {
+         // If no bigrams for this exact word, search again in lower case.
+         pos = getBigramListPositionForWord(dictStructurePolicy, codePoints,
+                 codePointCount, true /* forceLowerCaseSearch */);
+     }
+     return pos;
}
static int getBigramListPositionForWord(
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp
index 1858441..0247870 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp
@@ -181,9 +181,19 @@
DynamicPtReadingHelper readingHelper(&mNodeReader, &mPtNodeArrayReader);
readingHelper.initWithPtNodeArrayPos(getRootPosition());
bool addedNewUnigram = false;
- if (mUpdatingHelper.addUnigramWord(&readingHelper, word, length,
+ int codePointsToAdd[MAX_WORD_LENGTH];
+ int codePointCountToAdd = length;
+ memmove(codePointsToAdd, word, sizeof(int) * length);
+ if (unigramProperty->representsBeginningOfSentence()) {
+ codePointCountToAdd = CharUtils::attachBeginningOfSentenceMarker(codePointsToAdd,
+ codePointCountToAdd, MAX_WORD_LENGTH);
+ }
+ if (codePointCountToAdd <= 0) {
+ return false;
+ }
+ if (mUpdatingHelper.addUnigramWord(&readingHelper, codePointsToAdd, codePointCountToAdd,
unigramProperty, &addedNewUnigram)) {
- if (addedNewUnigram) {
+ if (addedNewUnigram && !unigramProperty->representsBeginningOfSentence()) {
mUnigramCount++;
}
if (unigramProperty->getShortcuts().size() > 0) {
diff --git a/native/jni/src/utils/char_utils.h b/native/jni/src/utils/char_utils.h
index 634c45b..f28ed56 100644
--- a/native/jni/src/utils/char_utils.h
+++ b/native/jni/src/utils/char_utils.h
@@ -18,6 +18,7 @@
#define LATINIME_CHAR_UTILS_H
#include <cctype>
+#include <cstring>
#include <vector>
#include "defines.h"
@@ -93,6 +94,19 @@
static unsigned short latin_tolower(const unsigned short c);
static const std::vector<int> EMPTY_STRING;
+ // Prepends CODE_POINT_BEGINNING_OF_SENTENCE to the code points stored in |codePoints|,
+ // shifting the existing |codePointCount| entries up by one slot.
+ //
+ // codePoints: buffer that is modified in place.
+ // codePointCount: number of valid code points currently in |codePoints|.
+ // maxCodePoint: capacity of |codePoints| in entries (a buffer size, not a code point
+ //         value). The buffer must really have room for that many entries.
+ //
+ // Returns updated code point count. Returns 0, leaving the buffer untouched, when the code
+ // points cannot be marked as a Beginning-of-Sentence: the buffer is null, the count is
+ // negative, or there is no free slot left for the marker.
+ static AK_FORCE_INLINE int attachBeginningOfSentenceMarker(int *const codePoints,
+         const int codePointCount, const int maxCodePoint) {
+     if (!codePoints || codePointCount < 0 || codePointCount >= maxCodePoint) {
+         // the code points cannot be marked as a Beginning-of-Sentence. A negative count
+         // is rejected here because it would become a huge size_t in the memmove below.
+         return 0;
+     }
+     // Source and destination ranges overlap, so memmove (not memcpy) is required.
+     memmove(codePoints + 1, codePoints, sizeof(int) * codePointCount);
+     codePoints[0] = CODE_POINT_BEGINNING_OF_SENTENCE;
+     return codePointCount + 1;
+ }
+
private:
DISALLOW_IMPLICIT_CONSTRUCTORS(CharUtils);