Welcome to mirror list, hosted at ThFree Co, Russian Federation.

TargetNgramFeature.h « src « moses - github.com/moses-smt/mosesdecoder.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
blob: 681e7d6aae960d1c7a1c5633caab740b20180f95 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
#ifndef moses_TargetNgramFeature_h
#define moses_TargetNgramFeature_h

#include <string>
#include <map>

#include "FactorCollection.h"
#include "FeatureFunction.h"
#include "FFState.h"
#include "Word.h"

#include "LM/SingleFactor.h"
#include "ChartHypothesis.h"
#include "ChartManager.h"

namespace Moses
{

class TargetNgramState : public FFState {
  public:
		TargetNgramState(std::vector<Word> &words): m_words(words) {}
		const std::vector<Word> GetWords() const {return m_words;}
    virtual int Compare(const FFState& other) const;

  private:
    std::vector<Word> m_words;
};

class TargetNgramChartState : public FFState
{
private:
  Phrase m_contextPrefix, m_contextSuffix;

  size_t m_numTargetTerminals; // This isn't really correct except for the surviving hypothesis

  size_t m_startPos, m_endPos, m_inputSize;

  /** Construct the prefix string of up to specified size
   * \param ret prefix string
   * \param size maximum size (typically max lm context window)
   */
  size_t CalcPrefix(const ChartHypothesis &hypo, const int featureId, Phrase &ret, size_t size) const
  {
    const TargetPhrase &target = hypo.GetCurrTargetPhrase();
    const AlignmentInfo::NonTermIndexMap &nonTermIndexMap =
      target.GetAlignmentInfo().GetNonTermIndexMap();

    // loop over the rule that is being applied
    for (size_t pos = 0; pos < target.GetSize(); ++pos) {
      const Word &word = target.GetWord(pos);

      // for non-terminals, retrieve it from underlying hypothesis
      if (word.IsNonTerminal()) {
        size_t nonTermInd = nonTermIndexMap[pos];
        const ChartHypothesis *prevHypo = hypo.GetPrevHypo(nonTermInd);
        size = static_cast<const TargetNgramChartState*>(prevHypo->GetFFState(featureId))->CalcPrefix(*prevHypo, featureId, ret, size);
//        Phrase phrase = static_cast<const TargetNgramChartState*>(prevHypo->GetFFState(featureId))->GetPrefix();
//        size = phrase.GetSize();
      }
      // for words, add word
      else {
        ret.AddWord(word);
        size--;
      }

      // finish when maximum length reached
      if (size==0)
        break;
    }

    return size;
  }

  /** Construct the suffix phrase of up to specified size
   * will always be called after the construction of prefix phrase
   * \param ret suffix phrase
   * \param size maximum size of suffix
   */
  size_t CalcSuffix(const ChartHypothesis &hypo, int featureId, Phrase &ret, size_t size) const
  {
  	size_t prefixSize = m_contextPrefix.GetSize();
    assert(prefixSize <= m_numTargetTerminals);

    // special handling for small hypotheses
    // does the prefix match the entire hypothesis string? -> just copy prefix
    if (prefixSize == m_numTargetTerminals) {
      size_t maxCount = std::min(prefixSize, size);
      size_t pos= prefixSize - 1;

      for (size_t ind = 0; ind < maxCount; ++ind) {
        const Word &word = m_contextPrefix.GetWord(pos);
        ret.PrependWord(word);
        --pos;
      }

      size -= maxCount;
      return size;
    }
    // construct suffix analogous to prefix
    else {
    	const TargetPhrase targetPhrase = hypo.GetCurrTargetPhrase();
      const AlignmentInfo::NonTermIndexMap &nonTermIndexMap =
      		targetPhrase.GetAlignmentInfo().GetNonTermIndexMap();
      for (int pos = (int) targetPhrase.GetSize() - 1; pos >= 0 ; --pos) {
        const Word &word = targetPhrase.GetWord(pos);

        if (word.IsNonTerminal()) {
          size_t nonTermInd = nonTermIndexMap[pos];
          const ChartHypothesis *prevHypo = hypo.GetPrevHypo(nonTermInd);
          size = static_cast<const TargetNgramChartState*>(prevHypo->GetFFState(featureId))->CalcSuffix(*prevHypo, featureId, ret, size);
        }
        else {
          ret.PrependWord(word);
          size--;
        }

        if (size==0)
          break;
      }

      return size;
    }
  }

public:
  TargetNgramChartState(const ChartHypothesis &hypo, int featureId, size_t order)
      :m_contextPrefix(order - 1),
      m_contextSuffix(order - 1)
  {
    m_numTargetTerminals = hypo.GetCurrTargetPhrase().GetNumTerminals();
    const WordsRange range = hypo.GetCurrSourceRange();
    m_startPos = range.GetStartPos();
    m_endPos = range.GetEndPos();
    m_inputSize = hypo.GetManager().GetSource().GetSize();

    const std::vector<const ChartHypothesis*> prevHypos = hypo.GetPrevHypos();
    for (std::vector<const ChartHypothesis*>::const_iterator i = prevHypos.begin(); i != prevHypos.end(); ++i) {
      // keep count of words (= length of generated string)
      m_numTargetTerminals += static_cast<const TargetNgramChartState*>((*i)->GetFFState(featureId))->GetNumTargetTerminals();
    }

    CalcPrefix(hypo, featureId, m_contextPrefix, order - 1);
    CalcSuffix(hypo, featureId, m_contextSuffix, order - 1);
  }

  size_t GetNumTargetTerminals() const {
    return m_numTargetTerminals;
  }

  const Phrase &GetPrefix() const {
    return m_contextPrefix;
  }
  const Phrase &GetSuffix() const {
    return m_contextSuffix;
  }

  int Compare(const FFState& o) const {
    const TargetNgramChartState &other =
      static_cast<const TargetNgramChartState &>( o );

    // prefix
    if (m_startPos > 0) // not for "<s> ..."
    {
      int ret = GetPrefix().Compare(other.GetPrefix());
      if (ret != 0)
        return ret;
    }

    if (m_endPos < m_inputSize - 1)// not for "... </s>"
    {
      int ret = GetSuffix().Compare(other.GetSuffix());
      if (ret != 0)
        return ret;
    }
    return 0;
  }
};

/** Sets the features of observed ngrams.
 */
class TargetNgramFeature : public StatefulFeatureFunction {
public:
	TargetNgramFeature(FactorType factorType = 0, size_t n = 3, bool lower_ngrams = true):
     StatefulFeatureFunction("dlm", ScoreProducer::unlimited),
     m_factorType(factorType),
     m_n(n),
     m_lower_ngrams(lower_ngrams),
     m_sparseProducerWeight(1)
  {
    FactorCollection& factorCollection = FactorCollection::Instance();
    const Factor* bosFactor = factorCollection.AddFactor(Output,m_factorType,BOS_);
    m_bos.SetFactor(m_factorType,bosFactor);
    m_baseName = GetScoreProducerDescription();
    m_baseName.append("_");
  }

	bool Load(const std::string &filePath);

	std::string GetScoreProducerWeightShortName(unsigned) const;
	size_t GetNumInputScores() const;

  void SetSparseProducerWeight(float weight) { m_sparseProducerWeight = weight; }
  float GetSparseProducerWeight() const { return m_sparseProducerWeight; }

	virtual const FFState* EmptyHypothesisState(const InputType &input) const;

	virtual FFState* Evaluate(const Hypothesis& cur_hypo, const FFState* prev_state,
	                          ScoreComponentCollection* accumulator) const;

  virtual FFState* EvaluateChart(const ChartHypothesis& cur_hypo, int featureId,
                                  ScoreComponentCollection* accumulator) const;

private:
  FactorType m_factorType;
  Word m_bos;
	std::set<std::string> m_vocab;
	size_t m_n;
	bool m_lower_ngrams;

	// additional weight that all sparse weights are scaled with
	float m_sparseProducerWeight;

	std::string m_baseName;

	void appendNgram(const Word& word, bool& skip, std::stringstream& ngram) const;
	void MakePrefixNgrams(std::vector<const Word*> &contextFactor, ScoreComponentCollection* accumulator,
			      size_t numberOfStartPos = 1, size_t offset = 0) const;
	void MakeSuffixNgrams(std::vector<const Word*> &contextFactor, ScoreComponentCollection* accumulator,
			      size_t numberOfEndPos = 1, size_t offset = 0) const;
};

}

#endif // moses_TargetNgramFeature_h