#ifndef moses_TargetNgramFeature_h #define moses_TargetNgramFeature_h #include #include #include #include "StatefulFeatureFunction.h" #include "moses/FF/FFState.h" #include "moses/Word.h" #include "moses/FactorCollection.h" #include "moses/LM/SingleFactor.h" #include "moses/ChartHypothesis.h" #include "moses/ChartManager.h" #include "util/string_stream.hh" namespace Moses { class TargetNgramState : public FFState { public: TargetNgramState() {} TargetNgramState(const std::vector &words): m_words(words) {} const std::vector GetWords() const { return m_words; } size_t hash() const; virtual bool operator==(const FFState& other) const; private: std::vector m_words; }; class TargetNgramChartState : public FFState { private: Phrase m_contextPrefix, m_contextSuffix; size_t m_numTargetTerminals; // This isn't really correct except for the surviving hypothesis size_t m_startPos, m_endPos, m_inputSize; /** Construct the prefix string of up to specified size * \param ret prefix string * \param size maximum size (typically max lm context window) */ size_t CalcPrefix(const ChartHypothesis &hypo, const int featureId, Phrase &ret, size_t size) const { const TargetPhrase &target = hypo.GetCurrTargetPhrase(); const AlignmentInfo::NonTermIndexMap &nonTermIndexMap = target.GetAlignNonTerm().GetNonTermIndexMap(); // loop over the rule that is being applied for (size_t pos = 0; pos < target.GetSize(); ++pos) { const Word &word = target.GetWord(pos); // for non-terminals, retrieve it from underlying hypothesis if (word.IsNonTerminal()) { size_t nonTermInd = nonTermIndexMap[pos]; const ChartHypothesis *prevHypo = hypo.GetPrevHypo(nonTermInd); size = static_cast(prevHypo->GetFFState(featureId))->CalcPrefix(*prevHypo, featureId, ret, size); // Phrase phrase = static_cast(prevHypo->GetFFState(featureId))->GetPrefix(); // size = phrase.GetSize(); } // for words, add word else { ret.AddWord(word); size--; } // finish when maximum length reached if (size==0) break; } return size; } /** Construct the suffix phrase of up to specified size * will always be called after the construction of prefix phrase * \param ret suffix phrase * \param size maximum size of suffix */ size_t CalcSuffix(const ChartHypothesis &hypo, int featureId, Phrase &ret, size_t size) const { size_t prefixSize = m_contextPrefix.GetSize(); assert(prefixSize <= m_numTargetTerminals); // special handling for small hypotheses // does the prefix match the entire hypothesis string? -> just copy prefix if (prefixSize == m_numTargetTerminals) { size_t maxCount = std::min(prefixSize, size); size_t pos= prefixSize - 1; for (size_t ind = 0; ind < maxCount; ++ind) { const Word &word = m_contextPrefix.GetWord(pos); ret.PrependWord(word); --pos; } size -= maxCount; return size; } // construct suffix analogous to prefix else { const TargetPhrase targetPhrase = hypo.GetCurrTargetPhrase(); const AlignmentInfo::NonTermIndexMap &nonTermIndexMap = targetPhrase.GetAlignTerm().GetNonTermIndexMap(); for (int pos = (int) targetPhrase.GetSize() - 1; pos >= 0 ; --pos) { const Word &word = targetPhrase.GetWord(pos); if (word.IsNonTerminal()) { size_t nonTermInd = nonTermIndexMap[pos]; const ChartHypothesis *prevHypo = hypo.GetPrevHypo(nonTermInd); size = static_cast(prevHypo->GetFFState(featureId))->CalcSuffix(*prevHypo, featureId, ret, size); } else { ret.PrependWord(word); size--; } if (size==0) break; } return size; } } public: TargetNgramChartState(const ChartHypothesis &hypo, int featureId, size_t order) :m_contextPrefix(order - 1), m_contextSuffix(order - 1) { m_numTargetTerminals = hypo.GetCurrTargetPhrase().GetNumTerminals(); const Range range = hypo.GetCurrSourceRange(); m_startPos = range.GetStartPos(); m_endPos = range.GetEndPos(); m_inputSize = hypo.GetManager().GetSource().GetSize(); const std::vector prevHypos = hypo.GetPrevHypos(); for (std::vector::const_iterator i = prevHypos.begin(); i != prevHypos.end(); ++i) { // keep count of words (= length of generated string) m_numTargetTerminals += static_cast((*i)->GetFFState(featureId))->GetNumTargetTerminals(); } CalcPrefix(hypo, featureId, m_contextPrefix, order - 1); CalcSuffix(hypo, featureId, m_contextSuffix, order - 1); } size_t GetNumTargetTerminals() const { return m_numTargetTerminals; } const Phrase &GetPrefix() const { return m_contextPrefix; } const Phrase &GetSuffix() const { return m_contextSuffix; } size_t hash() const { // not sure if this is correct size_t ret; ret = m_startPos; boost::hash_combine(ret, m_endPos); boost::hash_combine(ret, m_inputSize); // prefix if (m_startPos > 0) { // not for " ..." boost::hash_combine(ret, hash_value(GetPrefix())); } if (m_endPos < m_inputSize - 1) { // not for "... " boost::hash_combine(ret, hash_value(GetSuffix())); } return ret; } virtual bool operator==(const FFState& o) const { const TargetNgramChartState &other = static_cast( o ); // prefix if (m_startPos > 0) { // not for " ..." if (GetPrefix() != other.GetPrefix()) return false; } if (m_endPos < m_inputSize - 1) { // not for "... " if (GetSuffix() != other.GetSuffix()) return false; } return true; } }; /** Sets the features of observed ngrams. */ class TargetNgramFeature : public StatefulFeatureFunction { public: TargetNgramFeature(const std::string &line); void Load(); bool IsUseable(const FactorMask &mask) const; virtual const FFState* EmptyHypothesisState(const InputType &input) const; virtual FFState* EvaluateWhenApplied(const Hypothesis& cur_hypo, const FFState* prev_state, ScoreComponentCollection* accumulator) const; virtual FFState* EvaluateWhenApplied(const ChartHypothesis& cur_hypo, int featureId, ScoreComponentCollection* accumulator) const; void SetParameter(const std::string& key, const std::string& value); private: FactorType m_factorType; Word m_bos; boost::unordered_set m_vocab; size_t m_n; bool m_lower_ngrams; std::string m_file; std::string m_baseName; void appendNgram(const Word& word, bool& skip, util::StringStream& ngram) const; void MakePrefixNgrams(std::vector &contextFactor, ScoreComponentCollection* accumulator, size_t numberOfStartPos = 1, size_t offset = 0) const; void MakeSuffixNgrams(std::vector &contextFactor, ScoreComponentCollection* accumulator, size_t numberOfEndPos = 1, size_t offset = 0) const; }; } #endif // moses_TargetNgramFeature_h