Welcome to mirror list, hosted at ThFree Co, Russian Federation.

github.com/moses-smt/mosesdecoder.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'contrib/moses2/LM/KENLM.cpp')
-rw-r--r--contrib/moses2/LM/KENLM.cpp601
1 files changed, 601 insertions, 0 deletions
diff --git a/contrib/moses2/LM/KENLM.cpp b/contrib/moses2/LM/KENLM.cpp
new file mode 100644
index 000000000..3173392cd
--- /dev/null
+++ b/contrib/moses2/LM/KENLM.cpp
@@ -0,0 +1,601 @@
+/*
+ * KENLM.cpp
+ *
+ * Created on: 4 Nov 2015
+ * Author: hieu
+ */
+#include <sstream>
+#include <vector>
+#include "KENLM.h"
+#include "../Phrase.h"
+#include "../Scores.h"
+#include "../System.h"
+#include "../PhraseBased/Hypothesis.h"
+#include "../PhraseBased/Manager.h"
+#include "../PhraseBased/TargetPhraseImpl.h"
+#include "lm/state.hh"
+#include "lm/left.hh"
+#include "util/exception.hh"
+#include "util/tokenize_piece.hh"
+#include "util/string_stream.hh"
+#include "../legacy/FactorCollection.h"
+#include "../SCFG/TargetPhraseImpl.h"
+#include "../SCFG/Hypothesis.h"
+#include "../SCFG/Manager.h"
+
+using namespace std;
+
+namespace Moses2
+{
+
+struct KenLMState: public FFState
+{
+ lm::ngram::State state;
+ virtual size_t hash() const
+ {
+ size_t ret = hash_value(state);
+ return ret;
+ }
+ virtual bool operator==(const FFState& o) const
+ {
+ const KenLMState &other = static_cast<const KenLMState &>(o);
+ bool ret = state == other.state;
+ return ret;
+ }
+
+ virtual std::string ToString() const
+ {
+ stringstream ss;
+ for (size_t i = 0; i < state.Length(); ++i) {
+ ss << state.words[i] << " ";
+ }
+ return ss.str();
+ }
+
+};
+
+/////////////////////////////////////////////////////////////////
+class LanguageModelChartStateKenLM : public FFState
+{
+public:
+ LanguageModelChartStateKenLM() {}
+
+ const lm::ngram::ChartState &GetChartState() const {
+ return m_state;
+ }
+ lm::ngram::ChartState &GetChartState() {
+ return m_state;
+ }
+
+ size_t hash() const {
+ size_t ret = hash_value(m_state);
+ return ret;
+ }
+ virtual bool operator==(const FFState& o) const {
+ const LanguageModelChartStateKenLM &other = static_cast<const LanguageModelChartStateKenLM &>(o);
+ bool ret = m_state == other.m_state;
+ return ret;
+ }
+
+ virtual std::string ToString() const
+ {
+ return "LanguageModelChartStateKenLM";
+ }
+
+private:
+ lm::ngram::ChartState m_state;
+};
+
+/////////////////////////////////////////////////////////////////
+class MappingBuilder: public lm::EnumerateVocab
+{
+public:
+ MappingBuilder(FactorCollection &factorCollection, System &system,
+ std::vector<lm::WordIndex> &mapping) :
+ m_factorCollection(factorCollection), m_system(system), m_mapping(mapping)
+ {
+ }
+
+ void Add(lm::WordIndex index, const StringPiece &str)
+ {
+ std::size_t factorId = m_factorCollection.AddFactor(str, m_system, false)->GetId();
+ if (m_mapping.size() <= factorId) {
+ // 0 is <unk> :-)
+ m_mapping.resize(factorId + 1);
+ }
+ m_mapping[factorId] = index;
+ }
+
+private:
+ FactorCollection &m_factorCollection;
+ std::vector<lm::WordIndex> &m_mapping;
+ System &m_system;
+};
+
+/////////////////////////////////////////////////////////////////
// Construct the feature over the KenLM model at `file`, scoring only the
// given factor of each target word.  ReadParameters() (base class) applies
// the remaining key=value options from the config line.
template<class Model>
KENLM<Model>::KENLM(size_t startInd, const std::string &line,
    const std::string &file, FactorType factorType,
    util::LoadMethod load_method) :
    StatefulFeatureFunction(startInd, line), m_path(file), m_factorType(
        factorType), m_load_method(load_method)
{
  ReadParameters();
}
+
template<class Model>
KENLM<Model>::~KENLM()
{
  // Nothing to release by hand: m_ngram is assigned via .reset() in Load(),
  // so the model is freed automatically by its owning pointer.
}
+
+template<class Model>
+void KENLM<Model>::Load(System &system)
+{
+ FactorCollection &fc = system.GetVocab();
+
+ m_bos = fc.AddFactor(BOS_, system, false);
+ m_eos = fc.AddFactor(EOS_, system, false);
+
+ lm::ngram::Config config;
+ config.messages = NULL;
+
+ FactorCollection &collection = system.GetVocab();
+ MappingBuilder builder(collection, system, m_lmIdLookup);
+ config.enumerate_vocab = &builder;
+ config.load_method = m_load_method;
+
+ m_ngram.reset(new Model(m_path.c_str(), config));
+}
+
+template<class Model>
+FFState* KENLM<Model>::BlankState(MemPool &pool, const System &sys) const
+{
+ FFState *ret;
+ if (sys.isPb) {
+ ret = new (pool.Allocate<KenLMState>()) KenLMState();
+ }
+ else {
+ ret = new (pool.Allocate<LanguageModelChartStateKenLM>()) LanguageModelChartStateKenLM();
+ }
+ return ret;
+}
+
+//! return the state associated with the empty hypothesis for a given sentence
+template<class Model>
+void KENLM<Model>::EmptyHypothesisState(FFState &state, const ManagerBase &mgr,
+ const InputType &input, const Hypothesis &hypo) const
+{
+ KenLMState &stateCast = static_cast<KenLMState&>(state);
+ stateCast.state = m_ngram->BeginSentenceState();
+}
+
// Phrase-based isolated scoring of a target phrase.  CalcScore splits the
// LM mass into n-grams fully contained in the phrase (added to the real
// scores) and the under-contexted prefix (folded into the admissible
// future-cost estimate only).
template<class Model>
void KENLM<Model>::EvaluateInIsolation(MemPool &pool, const System &system,
    const Phrase<Moses2::Word> &source, const TargetPhraseImpl &targetPhrase, Scores &scores,
    SCORE &estimatedScore) const
{
  // contains factors used by this LM
  float fullScore, nGramScore;
  size_t oovCount;

  CalcScore(targetPhrase, fullScore, nGramScore, oovCount);

  // Portion of the score whose left context is incomplete: estimate only.
  float estimateScore = fullScore - nGramScore;

  // The OOV sub-feature is hard-wired off; the dead branch preserves the
  // two-component (LM score, OOV count) layout should it be re-enabled.
  bool GetLMEnableOOVFeature = false;
  if (GetLMEnableOOVFeature) {
    float scoresVec[2], estimateScoresVec[2];
    scoresVec[0] = nGramScore;
    scoresVec[1] = oovCount;
    scores.PlusEquals(system, *this, scoresVec);

    estimateScoresVec[0] = estimateScore;
    estimateScoresVec[1] = 0;
    SCORE weightedScore = Scores::CalcWeightedScore(system, *this,
        estimateScoresVec);
    estimatedScore += weightedScore;
  }
  else {
    scores.PlusEquals(system, *this, nGramScore);

    SCORE weightedScore = Scores::CalcWeightedScore(system, *this,
        estimateScore);
    estimatedScore += weightedScore;
  }
}
+
// SCFG variant of isolated scoring.  Because nonterminals get substituted
// later with unknown context, no part of the LM score is final here: the
// whole score is treated as an estimate (nGramScore forced to 0).
template<class Model>
void KENLM<Model>::EvaluateInIsolation(MemPool &pool, const System &system, const Phrase<SCFG::Word> &source,
    const TargetPhrase<SCFG::Word> &targetPhrase, Scores &scores,
    SCORE &estimatedScore) const
{
  // contains factors used by this LM
  float fullScore, nGramScore;
  size_t oovCount;

  CalcScore(targetPhrase, fullScore, nGramScore, oovCount);

  //float estimateScore = fullScore - nGramScore;

  // all LM scores are estimated
  float estimateScore = fullScore;
  nGramScore = 0;

  // OOV sub-feature hard-wired off — see the phrase-based variant.
  bool GetLMEnableOOVFeature = false;
  if (GetLMEnableOOVFeature) {
    float scoresVec[2], estimateScoresVec[2];
    scoresVec[0] = nGramScore;
    scoresVec[1] = oovCount;
    scores.PlusEquals(system, *this, scoresVec);

    estimateScoresVec[0] = estimateScore;
    estimateScoresVec[1] = 0;
    SCORE weightedScore = Scores::CalcWeightedScore(system, *this,
        estimateScoresVec);
    estimatedScore += weightedScore;
  }
  else {
    scores.PlusEquals(system, *this, nGramScore);

    SCORE weightedScore = Scores::CalcWeightedScore(system, *this,
        estimateScore);
    estimatedScore += weightedScore;
  }
}
+
// Incrementally score the words appended by this phrase-based hypothesis,
// continuing from the LM context in prevState, and leave the resulting
// context in `state` for later extensions.
template<class Model>
void KENLM<Model>::EvaluateWhenApplied(const ManagerBase &mgr,
    const Hypothesis &hypo, const FFState &prevState, Scores &scores,
    FFState &state) const
{
  KenLMState &stateCast = static_cast<KenLMState&>(state);

  const System &system = mgr.system;

  const lm::ngram::State &in_state =
      static_cast<const KenLMState&>(prevState).state;

  // Empty target phrase: context unchanged, nothing to score.
  if (!hypo.GetTargetPhrase().GetSize()) {
    stateCast.state = in_state;
    return;
  }

  const std::size_t begin = hypo.GetCurrTargetWordsRange().GetStartPos();
  //[begin, end) in STL-like fashion.
  const std::size_t end = hypo.GetCurrTargetWordsRange().GetEndPos() + 1;
  // Only the first order-1 new words extend the incoming context word by
  // word; any remainder is handled below via FullScoreForgotState/GetState.
  const std::size_t adjust_end = std::min(end, begin + m_ngram->Order() - 1);

  // Ping-pong between two state buffers: each Score() call reads the
  // previous context and writes the next one.
  std::size_t position = begin;
  typename Model::State aux_state;
  typename Model::State *state0 = &stateCast.state, *state1 = &aux_state;

  float score = m_ngram->Score(in_state, TranslateID(hypo.GetWord(position)),
      *state0);
  ++position;
  for (; position < adjust_end; ++position) {
    score += m_ngram->Score(*state0, TranslateID(hypo.GetWord(position)),
        *state1);
    std::swap(state0, state1);
  }

  if (hypo.GetBitmap().IsComplete()) {
    // Score end of sentence.
    std::vector<lm::WordIndex> indices(m_ngram->Order() - 1);
    const lm::WordIndex *last = LastIDs(hypo, &indices.front());
    score += m_ngram->FullScoreForgotState(&indices.front(), last,
        m_ngram->GetVocabulary().EndSentence(), stateCast.state).prob;
  }
  else if (adjust_end < end) {
    // Get state after adding a long phrase.
    std::vector<lm::WordIndex> indices(m_ngram->Order() - 1);
    const lm::WordIndex *last = LastIDs(hypo, &indices.front());
    m_ngram->GetState(&indices.front(), last, stateCast.state);
  }
  else if (state0 != &stateCast.state) {
    // Short enough phrase that we can just reuse the state.
    stateCast.state = *state0;
  }

  // NOTE(review): TransformLMScore presumably converts KenLM's log10
  // score to the decoder's internal log domain — confirm in KENLM.h.
  score = TransformLMScore(score);

  // OOV sub-feature hard-wired off; dead branch kept for the
  // two-score layout.
  bool OOVFeatureEnabled = false;
  if (OOVFeatureEnabled) {
    std::vector<float> scoresVec(2);
    scoresVec[0] = score;
    scoresVec[1] = 0.0;
    scores.PlusEquals(system, *this, scoresVec);
  }
  else {
    scores.PlusEquals(system, *this, score);
  }
}
+
// Score `phrase` in isolation.  Outputs:
//   fullScore  - transformed LM score of all words;
//   ngramScore - transformed score of the words from position order-1
//                onwards, i.e. n-grams whose full context lies inside
//                the phrase;
//   oovCount   - number of words mapping to KenLM's <unk> (id 0).
template<class Model>
void KENLM<Model>::CalcScore(const Phrase<Moses2::Word> &phrase, float &fullScore,
    float &ngramScore, std::size_t &oovCount) const
{
  fullScore = 0;
  ngramScore = 0;
  oovCount = 0;

  if (!phrase.GetSize()) return;

  // The scorer requires a ChartState output, but it is unused here.
  lm::ngram::ChartState discarded_sadly;
  lm::ngram::RuleScore<Model> scorer(*m_ngram, discarded_sadly);

  size_t position;
  if (m_bos == phrase[0][m_factorType]) {
    // Leading <s> provides context only; it is not itself scored.
    scorer.BeginSentence();
    position = 1;
  }
  else {
    position = 0;
  }

  size_t ngramBoundary = m_ngram->Order() - 1;

  // First pass: words that lack full left context inside the phrase.
  size_t end_loop = std::min(ngramBoundary, phrase.GetSize());
  for (; position < end_loop; ++position) {
    const Word &word = phrase[position];
    lm::WordIndex index = TranslateID(word);
    scorer.Terminal(index);
    if (!index) ++oovCount;
  }
  // Snapshot of the mass accumulated before the n-gram boundary.
  float before_boundary = fullScore + scorer.Finish();
  // Second pass: fully-contexted words.
  for (; position < phrase.GetSize(); ++position) {
    const Word &word = phrase[position];
    lm::WordIndex index = TranslateID(word);
    scorer.Terminal(index);
    if (!index) ++oovCount;
  }
  fullScore += scorer.Finish();

  ngramScore = TransformLMScore(fullScore - before_boundary);
  fullScore = TransformLMScore(fullScore);
}
+
// SCFG variant of isolated scoring.  Works like the phrase-based version,
// except that a nonterminal breaks the n-gram context: the scorer is
// flushed (its mass added to fullScore) and restarted after each one.
template<class Model>
void KENLM<Model>::CalcScore(const Phrase<SCFG::Word> &phrase, float &fullScore,
    float &ngramScore, std::size_t &oovCount) const
{
  fullScore = 0;
  ngramScore = 0;
  oovCount = 0;

  if (!phrase.GetSize()) return;

  // The scorer requires a ChartState output, but it is unused here.
  lm::ngram::ChartState discarded_sadly;
  lm::ngram::RuleScore<Model> scorer(*m_ngram, discarded_sadly);

  size_t position;
  if (m_bos == phrase[0][m_factorType]) {
    // Leading <s> provides context only; it is not itself scored.
    scorer.BeginSentence();
    position = 1;
  } else {
    position = 0;
  }

  size_t ngramBoundary = m_ngram->Order() - 1;

  // First pass: words that lack full left context inside the phrase.
  size_t end_loop = std::min(ngramBoundary, phrase.GetSize());
  for (; position < end_loop; ++position) {
    const SCFG::Word &word = phrase[position];
    if (word.isNonTerminal) {
      // Context is unknown across a nonterminal: flush and restart.
      fullScore += scorer.Finish();
      scorer.Reset();
    } else {
      lm::WordIndex index = TranslateID(word);
      scorer.Terminal(index);
      if (!index) ++oovCount;
    }
  }
  // Snapshot of the mass accumulated before the n-gram boundary.
  float before_boundary = fullScore + scorer.Finish();
  // Second pass: the remaining words.
  for (; position < phrase.GetSize(); ++position) {
    const SCFG::Word &word = phrase[position];
    if (word.isNonTerminal) {
      // Context is unknown across a nonterminal: flush and restart.
      fullScore += scorer.Finish();
      scorer.Reset();
    } else {
      lm::WordIndex index = TranslateID(word);
      scorer.Terminal(index);
      if (!index) ++oovCount;
    }
  }
  fullScore += scorer.Finish();

  ngramScore = TransformLMScore(fullScore - before_boundary);
  fullScore = TransformLMScore(fullScore);
}
+
// Convert last words of hypothesis into vocab ids, returning an end pointer.
// Walks backwards from the final word of `hypo`, writing up to order-1
// KenLM ids into `indices` (most recent word first).  If the start of the
// translation output is reached first, a <s> id is appended and the walk
// stops there.
template<class Model>
lm::WordIndex *KENLM<Model>::LastIDs(const Hypothesis &hypo,
    lm::WordIndex *indices) const
{
  lm::WordIndex *index = indices;
  lm::WordIndex *end = indices + m_ngram->Order() - 1;
  // signed so it can step past position 0 to the -1 sentinel
  int position = hypo.GetCurrTargetWordsRange().GetEndPos();
  for (;; ++index, --position) {
    if (index == end) return index;
    if (position == -1) {
      // Ran past the sentence start: context begins with <s>.
      *index = m_ngram->GetVocabulary().BeginSentence();
      return index + 1;
    }
    *index = TranslateID(hypo.GetWord(position));
  }
}
+
// Chart (SCFG) scoring: score the target side of the applied rule, gluing
// in the previously computed ChartStates of its nonterminals, and store
// the resulting ChartState in `state`.
template<class Model>
void KENLM<Model>::EvaluateWhenApplied(const SCFG::Manager &mgr,
    const SCFG::Hypothesis &hypo, int featureID, Scores &scores,
    FFState &state) const
{
  LanguageModelChartStateKenLM &newState = static_cast<LanguageModelChartStateKenLM&>(state);
  lm::ngram::RuleScore<Model> ruleScore(*m_ngram, newState.GetChartState());
  const SCFG::TargetPhraseImpl &target = hypo.GetTargetPhrase();
  // Maps a nonterminal's target-side position to the index of the
  // previous hypothesis that fills it.
  const AlignmentInfo::NonTermIndexMap &nonTermIndexMap =
      target.GetAlignNonTerm().GetNonTermIndexMap();

  const size_t size = target.GetSize();
  size_t phrasePos = 0;
  // Special cases for first word.
  if (size) {
    const SCFG::Word &word = target[0];
    if (word[m_factorType] == m_bos) {
      // Begin of sentence
      ruleScore.BeginSentence();
      phrasePos++;
    } else if (word.isNonTerminal) {
      // Non-terminal is first so we can copy instead of rescoring.
      const SCFG::Hypothesis *prevHypo = hypo.GetPrevHypo(nonTermIndexMap[phrasePos]);
      const lm::ngram::ChartState &prevState = static_cast<const LanguageModelChartStateKenLM*>(prevHypo->GetState(featureID))->GetChartState();
      ruleScore.BeginNonTerminal(prevState);
      phrasePos++;
    }
  }

  // Remaining symbols: terminals are scored directly; nonterminals splice
  // in the ChartState of the corresponding sub-hypothesis.
  for (; phrasePos < size; phrasePos++) {
    const SCFG::Word &word = target[phrasePos];
    if (word.isNonTerminal) {
      const SCFG::Hypothesis *prevHypo = hypo.GetPrevHypo(nonTermIndexMap[phrasePos]);
      const lm::ngram::ChartState &prevState = static_cast<const LanguageModelChartStateKenLM*>(prevHypo->GetState(featureID))->GetChartState();
      ruleScore.NonTerminal(prevState);
    } else {
      ruleScore.Terminal(TranslateID(word));
    }
  }

  float score = ruleScore.Finish();
  score = TransformLMScore(score);

  // take out score from loading. This needs reworking
  //score -= target.GetScores().GetScores(*this)[0];

  // OOV sub-feature hard-wired off — see the phrase-based variant.
  bool OOVFeatureEnabled = false;
  if (OOVFeatureEnabled) {
    std::vector<float> scoresVec(2);
    scoresVec[0] = score;
    scoresVec[1] = 0.0;
    scores.PlusEquals(mgr.system, *this, scoresVec);
  } else {
    scores.PlusEquals(mgr.system, *this, score);
  }
}
+
+///////////////////////////////////////////////////////////////////////////
+
+/* Instantiate LanguageModelKen here. Tells the compiler to generate code
+ * for the instantiations' non-inline member functions in this file.
+ * Otherwise, depending on the compiler, those functions may not be present
+ * at link time.
+ */
+template class KENLM<lm::ngram::ProbingModel> ;
+template class KENLM<lm::ngram::RestProbingModel> ;
+template class KENLM<lm::ngram::TrieModel> ;
+template class KENLM<lm::ngram::ArrayTrieModel> ;
+template class KENLM<lm::ngram::QuantTrieModel> ;
+template class KENLM<lm::ngram::QuantArrayTrieModel> ;
+
+FeatureFunction *ConstructKenLM(size_t startInd, const std::string &lineOrig)
+{
+ FactorType factorType = 0;
+ string filePath;
+ util::LoadMethod load_method = util::POPULATE_OR_READ;
+
+ util::TokenIter<util::SingleCharacter, true> argument(lineOrig, ' ');
+ ++argument; // KENLM
+
+ util::StringStream line;
+ line << "KENLM";
+
+ for (; argument; ++argument) {
+ const char *equals = std::find(argument->data(),
+ argument->data() + argument->size(), '=');
+ UTIL_THROW_IF2(equals == argument->data() + argument->size(),
+ "Expected = in KenLM argument " << *argument);
+ StringPiece name(argument->data(), equals - argument->data());
+ StringPiece value(equals + 1,
+ argument->data() + argument->size() - equals - 1);
+ if (name == "factor") {
+ factorType = boost::lexical_cast<FactorType>(value);
+ }
+ else if (name == "order") {
+ // Ignored
+ }
+ else if (name == "path") {
+ filePath.assign(value.data(), value.size());
+ }
+ else if (name == "lazyken") {
+ // deprecated: use load instead.
+ load_method =
+ boost::lexical_cast<bool>(value) ?
+ util::LAZY : util::POPULATE_OR_READ;
+ }
+ else if (name == "load") {
+ if (value == "lazy") {
+ load_method = util::LAZY;
+ }
+ else if (value == "populate_or_lazy") {
+ load_method = util::POPULATE_OR_LAZY;
+ }
+ else if (value == "populate_or_read" || value == "populate") {
+ load_method = util::POPULATE_OR_READ;
+ }
+ else if (value == "read") {
+ load_method = util::READ;
+ }
+ else if (value == "parallel_read") {
+ load_method = util::PARALLEL_READ;
+ }
+ else {
+ UTIL_THROW2("Unknown KenLM load method " << value);
+ }
+ }
+ else {
+ // pass to base class to interpret
+ line << " " << name << "=" << value;
+ }
+ }
+
+ return ConstructKenLM(startInd, line.str(), filePath, factorType, load_method);
+}
+
+FeatureFunction *ConstructKenLM(size_t startInd, const std::string &line,
+ const std::string &file, FactorType factorType,
+ util::LoadMethod load_method)
+{
+ lm::ngram::ModelType model_type;
+ if (lm::ngram::RecognizeBinary(file.c_str(), model_type)) {
+ switch (model_type) {
+ case lm::ngram::PROBING:
+ return new KENLM<lm::ngram::ProbingModel>(startInd, line, file,
+ factorType, load_method);
+ case lm::ngram::REST_PROBING:
+ return new KENLM<lm::ngram::RestProbingModel>(startInd, line, file,
+ factorType, load_method);
+ case lm::ngram::TRIE:
+ return new KENLM<lm::ngram::TrieModel>(startInd, line, file, factorType,
+ load_method);
+ case lm::ngram::QUANT_TRIE:
+ return new KENLM<lm::ngram::QuantTrieModel>(startInd, line, file,
+ factorType, load_method);
+ case lm::ngram::ARRAY_TRIE:
+ return new KENLM<lm::ngram::ArrayTrieModel>(startInd, line, file,
+ factorType, load_method);
+ case lm::ngram::QUANT_ARRAY_TRIE:
+ return new KENLM<lm::ngram::QuantArrayTrieModel>(startInd, line, file,
+ factorType, load_method);
+ default:
+ UTIL_THROW2("Unrecognized kenlm model type " << model_type)
+ ;
+ }
+ }
+ else {
+ return new KENLM<lm::ngram::ProbingModel>(startInd, line, file, factorType,
+ load_method);
+ }
+}
+
+}
+