Use same language weight for all dictionaries.
Bug: 8187060
Change-Id: Ib9d8a8aed2c141137c1bb3c748a89fb8216293e7
diff --git a/native/jni/src/defines.h b/native/jni/src/defines.h
index 3651cd5..6c54305 100644
--- a/native/jni/src/defines.h
+++ b/native/jni/src/defines.h
@@ -309,6 +309,7 @@
#define NOT_A_PROBABILITY (-1)
#define NOT_A_DICT_POS (S_INT_MIN)
#define NOT_A_TIMESTAMP (-1)
+#define NOT_A_LANGUAGE_WEIGHT (-1.0f)
// A special value to mean the first word confidence makes no sense in this case,
// e.g. this is not a multi-word suggestion.
diff --git a/native/jni/src/suggest/core/dictionary/dictionary.cpp b/native/jni/src/suggest/core/dictionary/dictionary.cpp
index ae4646d..ef7a0a8 100644
--- a/native/jni/src/suggest/core/dictionary/dictionary.cpp
+++ b/native/jni/src/suggest/core/dictionary/dictionary.cpp
@@ -47,7 +47,7 @@
void Dictionary::getSuggestions(ProximityInfo *proximityInfo, DicTraverseSession *traverseSession,
int *xcoordinates, int *ycoordinates, int *times, int *pointerIds, int *inputCodePoints,
int inputSize, int *prevWordCodePoints, int prevWordLength,
- const SuggestOptions *const suggestOptions,
+ const SuggestOptions *const suggestOptions, const float languageWeight,
SuggestionResults *const outSuggestionResults) const {
TimeKeeper::setCurrentTime();
DicTraverseSession::initSessionInstance(
@@ -55,11 +55,11 @@
if (suggestOptions->isGesture()) {
mGestureSuggest->getSuggestions(proximityInfo, traverseSession, xcoordinates,
ycoordinates, times, pointerIds, inputCodePoints, inputSize,
- outSuggestionResults);
+ languageWeight, outSuggestionResults);
} else {
mTypingSuggest->getSuggestions(proximityInfo, traverseSession, xcoordinates,
ycoordinates, times, pointerIds, inputCodePoints, inputSize,
- outSuggestionResults);
+ languageWeight, outSuggestionResults);
}
if (DEBUG_DICT) {
outSuggestionResults->dumpSuggestions();
diff --git a/native/jni/src/suggest/core/dictionary/dictionary.h b/native/jni/src/suggest/core/dictionary/dictionary.h
index df5fc9b..cd983b0 100644
--- a/native/jni/src/suggest/core/dictionary/dictionary.h
+++ b/native/jni/src/suggest/core/dictionary/dictionary.h
@@ -65,7 +65,7 @@
void getSuggestions(ProximityInfo *proximityInfo, DicTraverseSession *traverseSession,
int *xcoordinates, int *ycoordinates, int *times, int *pointerIds, int *inputCodePoints,
int inputSize, int *prevWordCodePoints, int prevWordLength,
- const SuggestOptions *const suggestOptions,
+ const SuggestOptions *const suggestOptions, const float languageWeight,
SuggestionResults *const outSuggestionResults) const;
void getPredictions(const int *word, int length,
diff --git a/native/jni/src/suggest/core/result/suggestion_results.cpp b/native/jni/src/suggest/core/result/suggestion_results.cpp
index da1c6bc..088a55f 100644
--- a/native/jni/src/suggest/core/result/suggestion_results.cpp
+++ b/native/jni/src/suggest/core/result/suggestion_results.cpp
@@ -20,7 +20,8 @@
void SuggestionResults::outputSuggestions(JNIEnv *env, jintArray outSuggestionCount,
jintArray outputCodePointsArray, jintArray outScoresArray, jintArray outSpaceIndicesArray,
- jintArray outTypesArray, jintArray outAutoCommitFirstWordConfidenceArray) {
+ jintArray outTypesArray, jintArray outAutoCommitFirstWordConfidenceArray,
+ jfloatArray outLanguageWeight) {
int outputIndex = 0;
while (!mSuggestedWords.empty()) {
const SuggestedWord &suggestedWord = mSuggestedWords.top();
@@ -50,6 +51,7 @@
mSuggestedWords.pop();
}
env->SetIntArrayRegion(outSuggestionCount, 0 /* start */, 1 /* len */, &outputIndex);
+ env->SetFloatArrayRegion(outLanguageWeight, 0 /* start */, 1 /* len */, &mLanguageWeight);
}
void SuggestionResults::addPrediction(const int *const codePoints, const int codePointCount,
@@ -94,6 +96,7 @@
}
void SuggestionResults::dumpSuggestions() const {
+ AKLOGE("language weight: %f", mLanguageWeight);
std::vector<SuggestedWord> suggestedWords;
auto copyOfSuggestedWords = mSuggestedWords;
while (!copyOfSuggestedWords.empty()) {
diff --git a/native/jni/src/suggest/core/result/suggestion_results.h b/native/jni/src/suggest/core/result/suggestion_results.h
index 020bab4..8e845e2 100644
--- a/native/jni/src/suggest/core/result/suggestion_results.h
+++ b/native/jni/src/suggest/core/result/suggestion_results.h
@@ -29,12 +29,13 @@
class SuggestionResults {
public:
explicit SuggestionResults(const int maxSuggestionCount)
- : mMaxSuggestionCount(maxSuggestionCount), mSuggestedWords() {}
+ : mMaxSuggestionCount(maxSuggestionCount), mLanguageWeight(NOT_A_LANGUAGE_WEIGHT),
+ mSuggestedWords() {}
// Returns suggestion count.
void outputSuggestions(JNIEnv *env, jintArray outSuggestionCount, jintArray outCodePointsArray,
jintArray outScoresArray, jintArray outSpaceIndicesArray, jintArray outTypesArray,
- jintArray outAutoCommitFirstWordConfidenceArray);
+ jintArray outAutoCommitFirstWordConfidenceArray, jfloatArray outLanguageWeight);
void addPrediction(const int *const codePoints, const int codePointCount, const int score);
void addSuggestion(const int *const codePoints, const int codePointCount,
const int score, const int type, const int indexToPartialCommit,
@@ -42,6 +43,10 @@
void getSortedScores(int *const outScores) const;
void dumpSuggestions() const;
+ void setLanguageWeight(const float languageWeight) {
+ mLanguageWeight = languageWeight;
+ }
+
int getSuggestionCount() const {
return mSuggestedWords.size();
}
@@ -50,6 +55,7 @@
DISALLOW_IMPLICIT_CONSTRUCTORS(SuggestionResults);
const int mMaxSuggestionCount;
+ float mLanguageWeight;
std::priority_queue<
SuggestedWord, std::vector<SuggestedWord>, SuggestedWord::Comparator> mSuggestedWords;
};
diff --git a/native/jni/src/suggest/core/result/suggestions_output_utils.cpp b/native/jni/src/suggest/core/result/suggestions_output_utils.cpp
index 83140f1..a307cb4 100644
--- a/native/jni/src/suggest/core/result/suggestions_output_utils.cpp
+++ b/native/jni/src/suggest/core/result/suggestions_output_utils.cpp
@@ -33,7 +33,7 @@
/* static */ void SuggestionsOutputUtils::outputSuggestions(
const Scoring *const scoringPolicy, DicTraverseSession *traverseSession,
- SuggestionResults *const outSuggestionResults) {
+ const float languageWeight, SuggestionResults *const outSuggestionResults) {
#if DEBUG_EVALUATE_MOST_PROBABLE_STRING
const int terminalSize = 0;
#else
@@ -43,9 +43,12 @@
for (int index = terminalSize - 1; index >= 0; --index) {
traverseSession->getDicTraverseCache()->popTerminal(&terminals[index]);
}
-
- const float languageWeight = scoringPolicy->getAdjustedLanguageWeight(
- traverseSession, terminals.data(), terminalSize);
+ // Compute a language weight when an invalid language weight is passed.
+ // NOT_A_LANGUAGE_WEIGHT (-1) is assumed as an invalid language weight.
+ const float languageWeightToOutputSuggestions = (languageWeight < 0.0f) ?
+ scoringPolicy->getAdjustedLanguageWeight(
+ traverseSession, terminals.data(), terminalSize) : languageWeight;
+ outSuggestionResults->setLanguageWeight(languageWeightToOutputSuggestions);
// Force autocorrection for obvious long multi-word suggestions when the top suggestion is
// a long multiple words suggestion.
// TODO: Implement a smarter auto-commit method for handling multi-word suggestions.
@@ -61,10 +64,11 @@
// Output suggestion results here
for (auto &terminalDicNode : terminals) {
outputSuggestionsOfDicNode(scoringPolicy, traverseSession, &terminalDicNode,
- languageWeight, boostExactMatches, forceCommitMultiWords,
+ languageWeightToOutputSuggestions, boostExactMatches, forceCommitMultiWords,
outputSecondWordFirstLetterInputIndex, outSuggestionResults);
}
- scoringPolicy->getMostProbableString(traverseSession, languageWeight, outSuggestionResults);
+ scoringPolicy->getMostProbableString(traverseSession, languageWeightToOutputSuggestions,
+ outSuggestionResults);
}
/* static */ void SuggestionsOutputUtils::outputSuggestionsOfDicNode(
diff --git a/native/jni/src/suggest/core/result/suggestions_output_utils.h b/native/jni/src/suggest/core/result/suggestions_output_utils.h
index 73cdb95..b099b47 100644
--- a/native/jni/src/suggest/core/result/suggestions_output_utils.h
+++ b/native/jni/src/suggest/core/result/suggestions_output_utils.h
@@ -33,7 +33,8 @@
* Outputs the final list of suggestions (i.e., terminal nodes).
*/
static void outputSuggestions(const Scoring *const scoringPolicy,
- DicTraverseSession *traverseSession, SuggestionResults *const outSuggestionResults);
+ DicTraverseSession *traverseSession, const float languageWeight,
+ SuggestionResults *const outSuggestionResults);
private:
DISALLOW_IMPLICIT_CONSTRUCTORS(SuggestionsOutputUtils);
diff --git a/native/jni/src/suggest/core/suggest.cpp b/native/jni/src/suggest/core/suggest.cpp
index 303182c..433820a 100644
--- a/native/jni/src/suggest/core/suggest.cpp
+++ b/native/jni/src/suggest/core/suggest.cpp
@@ -44,7 +44,8 @@
*/
void Suggest::getSuggestions(ProximityInfo *pInfo, void *traverseSession,
int *inputXs, int *inputYs, int *times, int *pointerIds, int *inputCodePoints,
- int inputSize, SuggestionResults *const outSuggestionResults) const {
+ int inputSize, const float languageWeight,
+ SuggestionResults *const outSuggestionResults) const {
PROF_OPEN;
PROF_START(0);
const float maxSpatialDistance = TRAVERSAL->getMaxSpatialDistance();
@@ -65,7 +66,8 @@
}
PROF_END(1);
PROF_START(2);
- SuggestionsOutputUtils::outputSuggestions(SCORING, tSession, outSuggestionResults);
+ SuggestionsOutputUtils::outputSuggestions(
+ SCORING, tSession, languageWeight, outSuggestionResults);
PROF_END(2);
PROF_CLOSE;
}
diff --git a/native/jni/src/suggest/core/suggest.h b/native/jni/src/suggest/core/suggest.h
index 13ad621..788e031 100644
--- a/native/jni/src/suggest/core/suggest.h
+++ b/native/jni/src/suggest/core/suggest.h
@@ -49,7 +49,7 @@
AK_FORCE_INLINE virtual ~Suggest() {}
void getSuggestions(ProximityInfo *pInfo, void *traverseSession, int *inputXs, int *inputYs,
int *times, int *pointerIds, int *inputCodePoints, int inputSize,
- SuggestionResults *const outSuggestionResults) const;
+ const float languageWeight, SuggestionResults *const outSuggestionResults) const;
private:
DISALLOW_IMPLICIT_CONSTRUCTORS(Suggest);
diff --git a/native/jni/src/suggest/core/suggest_interface.h b/native/jni/src/suggest/core/suggest_interface.h
index c3ffea9..a6e5aef 100644
--- a/native/jni/src/suggest/core/suggest_interface.h
+++ b/native/jni/src/suggest/core/suggest_interface.h
@@ -28,7 +28,7 @@
public:
virtual void getSuggestions(ProximityInfo *pInfo, void *traverseSession, int *inputXs,
int *inputYs, int *times, int *pointerIds, int *inputCodePoints, int inputSize,
- SuggestionResults *const suggestionResults) const = 0;
+ const float languageWeight, SuggestionResults *const suggestionResults) const = 0;
SuggestInterface() {}
virtual ~SuggestInterface() {}
private: