Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/moses-smt/mosesdecoder.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
path: root/moses
diff options
context:
space:
mode:
authorheafield <heafield@1f5c12ca-751b-0410-a591-d2e778427230>2011-10-12 17:04:12 +0400
committerheafield <heafield@1f5c12ca-751b-0410-a591-d2e778427230>2011-10-12 17:04:12 +0400
commitcd19f148269a2d1238d32f7b732a13e9d9facb96 (patch)
tree6e5a4591be216304842260bc8401b58a955ba5e4 /moses
parent81acd0ffa274468af982b805d7aca05548b87040 (diff)
Faster CalcScore implementation for KenLM
git-svn-id: https://mosesdecoder.svn.sourceforge.net/svnroot/mosesdecoder/trunk@4339 1f5c12ca-751b-0410-a591-d2e778427230
Diffstat (limited to 'moses')
-rw-r--r--moses/src/LanguageModelImplementation.h2
-rw-r--r--moses/src/LanguageModelKen.cpp46
2 files changed, 43 insertions, 5 deletions
diff --git a/moses/src/LanguageModelImplementation.h b/moses/src/LanguageModelImplementation.h
index 89917fcdd..b9b4c57ad 100644
--- a/moses/src/LanguageModelImplementation.h
+++ b/moses/src/LanguageModelImplementation.h
@@ -99,7 +99,7 @@ public:
virtual const FFState *GetBeginSentenceState() const = 0;
virtual FFState *NewState(const FFState *from = NULL) const = 0;
- void CalcScore(const Phrase &phrase, float &fullScore, float &ngramScore, size_t &oovCount) const;
+ virtual void CalcScore(const Phrase &phrase, float &fullScore, float &ngramScore, size_t &oovCount) const;
FFState *Evaluate(const Hypothesis &hypo, const FFState *ps, ScoreComponentCollection *out, const LanguageModel *feature) const;
diff --git a/moses/src/LanguageModelKen.cpp b/moses/src/LanguageModelKen.cpp
index 354530431..463aae694 100644
--- a/moses/src/LanguageModelKen.cpp
+++ b/moses/src/LanguageModelKen.cpp
@@ -47,15 +47,14 @@ LanguageModelKenBase::~LanguageModelKenBase() {}
namespace
{
- class LanguageModelChartStateKenLM : public FFState
- {
+class LanguageModelChartStateKenLM : public FFState {
private:
lm::ngram::ChartState m_state;
const ChartHypothesis *m_hypo;
public:
explicit LanguageModelChartStateKenLM(const ChartHypothesis &hypo)
- :m_hypo(&hypo)
+ :m_hypo(&hypo)
{}
const ChartHypothesis* GetHypothesis() const { return m_hypo; }
@@ -69,7 +68,7 @@ namespace
int ret = m_state.Compare(other.m_state);
return ret;
}
- };
+};
class MappingBuilder : public lm::EnumerateVocab
@@ -152,6 +151,8 @@ public:
int featureID,
ScoreComponentCollection *accumulator,
const LanguageModel *feature) const;
+
+ void CalcScore(const Phrase &phrase, float &fullScore, float &ngramScore, size_t &oovCount) const;
};
template <class Model>
@@ -200,6 +201,43 @@ FFState *LanguageModelKen<Model>::EvaluateChart(
return newState;
}
+template <class Model> void LanguageModelKen<Model>::CalcScore(const Phrase &phrase, float &fullScore, float &ngramScore, size_t &oovCount) const {
+ fullScore = 0;
+ ngramScore = 0;
+ oovCount = 0;
+
+ if (!phrase.GetSize()) return;
+
+ typename Model::State state_backing[2];
+ typename Model::State *state0 = &state_backing[0], *state1 = &state_backing[1];
+ size_t position;
+ if (phrase.GetWord(0) == GetSentenceStartArray()) {
+ *state0 = m_ngram->BeginSentenceState();
+ position = 1;
+ } else {
+ *state0 = m_ngram->NullContextState();
+ position = 0;
+ }
+
+ FactorType factorType = GetFactorType();
+ size_t ngramBoundary = m_ngram->Order() - 1;
+
+ for (; position < phrase.GetSize(); ++position) {
+ const Word &word = phrase.GetWord(position);
+ if (word.IsNonTerminal()) {
+ *state0 = m_ngram->NullContextState();
+ } else {
+ std::size_t factor = word.GetFactor(factorType)->GetId();
+ lm::WordIndex index = factor >= m_lmIdLookup.size() ? 0 : m_lmIdLookup[factor];
+ float score = TransformLMScore(m_ngram->Score(*state0, index, *state1));
+ std::swap(state0, state1);
+ if (position >= ngramBoundary) ngramScore += score;
+ fullScore += score;
+ if (!index) ++oovCount;
+ }
+ }
+}
+
template <class Model> void LanguageModelKen<Model>::TranslateIDs(const std::vector<const Word*> &contextFactor, lm::WordIndex *indices) const
{
FactorType factorType = GetFactorType();