Welcome to mirror list, hosted at ThFree Co, Russian Federation.

ChartState.h « LM « moses - github.com/moses-smt/mosesdecoder.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
blob: d4a5cfb30abf8609374624cec523955ccd4d14d1 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
#pragma once

#include "moses/FF/FFState.h"
#include "moses/ChartHypothesis.h"
#include "moses/ChartManager.h"

namespace Moses
{

class LanguageModelChartState : public FFState
{
private:
  float m_prefixScore;
  FFState* m_lmRightContext;

  Phrase m_contextPrefix, m_contextSuffix;

  size_t m_numTargetTerminals; // This isn't really correct except for the surviving hypothesis

  const ChartHypothesis &m_hypo;

  /** Construct the prefix string of up to specified size
   * \param ret prefix string
   * \param size maximum size (typically max lm context window)
   */
  size_t CalcPrefix(const ChartHypothesis &hypo, int featureID, Phrase &ret, size_t size) const {
    const TargetPhrase &target = hypo.GetCurrTargetPhrase();
    const AlignmentInfo::NonTermIndexMap &nonTermIndexMap =
      target.GetAlignNonTerm().GetNonTermIndexMap();

    // loop over the rule that is being applied
    for (size_t pos = 0; pos < target.GetSize(); ++pos) {
      const Word &word = target.GetWord(pos);

      // for non-terminals, retrieve it from underlying hypothesis
      if (word.IsNonTerminal()) {
        size_t nonTermInd = nonTermIndexMap[pos];
        const ChartHypothesis *prevHypo = hypo.GetPrevHypo(nonTermInd);
        size = static_cast<const LanguageModelChartState*>(prevHypo->GetFFState(featureID))->CalcPrefix(*prevHypo, featureID, ret, size);
      }
      // for words, add word
      else {
        ret.AddWord(target.GetWord(pos));
        size--;
      }

      // finish when maximum length reached
      if (size==0)
        break;
    }

    return size;
  }

  /** Construct the suffix phrase of up to specified size
   * will always be called after the construction of prefix phrase
   * \param ret suffix phrase
   * \param size maximum size of suffix
   */
  size_t CalcSuffix(const ChartHypothesis &hypo, int featureID, Phrase &ret, size_t size) const {
    UTIL_THROW_IF2(m_contextPrefix.GetSize() > m_numTargetTerminals, "Error");

    // special handling for small hypotheses
    // does the prefix match the entire hypothesis string? -> just copy prefix
    if (m_contextPrefix.GetSize() == m_numTargetTerminals) {
      size_t maxCount = std::min(m_contextPrefix.GetSize(), size);
      size_t pos= m_contextPrefix.GetSize() - 1;

      for (size_t ind = 0; ind < maxCount; ++ind) {
        const Word &word = m_contextPrefix.GetWord(pos);
        ret.PrependWord(word);
        --pos;
      }

      size -= maxCount;
      return size;
    }
    // construct suffix analogous to prefix
    else {
      const TargetPhrase& target = hypo.GetCurrTargetPhrase();
      const AlignmentInfo::NonTermIndexMap &nonTermIndexMap =
        target.GetAlignNonTerm().GetNonTermIndexMap();
      for (int pos = (int) target.GetSize() - 1; pos >= 0 ; --pos) {
        const Word &word = target.GetWord(pos);

        if (word.IsNonTerminal()) {
          size_t nonTermInd = nonTermIndexMap[pos];
          const ChartHypothesis *prevHypo = hypo.GetPrevHypo(nonTermInd);
          size = static_cast<const LanguageModelChartState*>(prevHypo->GetFFState(featureID))->CalcSuffix(*prevHypo, featureID, ret, size);
        } else {
          ret.PrependWord(hypo.GetCurrTargetPhrase().GetWord(pos));
          size--;
        }

        if (size==0)
          break;
      }

      return size;
    }
  }


public:
  LanguageModelChartState(const ChartHypothesis &hypo, int featureID, size_t order)
    :m_lmRightContext(NULL)
    ,m_contextPrefix(order - 1)
    ,m_contextSuffix( order - 1)
    ,m_hypo(hypo) {
    m_numTargetTerminals = hypo.GetCurrTargetPhrase().GetNumTerminals();

    for (std::vector<const ChartHypothesis*>::const_iterator i = hypo.GetPrevHypos().begin(); i != hypo.GetPrevHypos().end(); ++i) {
      // keep count of words (= length of generated string)
      m_numTargetTerminals += static_cast<const LanguageModelChartState*>((*i)->GetFFState(featureID))->GetNumTargetTerminals();
    }

    CalcPrefix(hypo, featureID, m_contextPrefix, order - 1);
    CalcSuffix(hypo, featureID, m_contextSuffix, order - 1);
  }

  ~LanguageModelChartState() {
    delete m_lmRightContext;
  }

  void Set(float prefixScore, FFState *rightState) {
    m_prefixScore = prefixScore;
    m_lmRightContext = rightState;
  }

  float GetPrefixScore() const {
    return m_prefixScore;
  }
  FFState* GetRightContext() const {
    return m_lmRightContext;
  }

  size_t GetNumTargetTerminals() const {
    return m_numTargetTerminals;
  }

  const Phrase &GetPrefix() const {
    return m_contextPrefix;
  }
  const Phrase &GetSuffix() const {
    return m_contextSuffix;
  }

  size_t hash() const {
    size_t ret;

    // prefix
    ret = m_hypo.GetCurrSourceRange().GetStartPos() > 0;
    if (m_hypo.GetCurrSourceRange().GetStartPos() > 0) { // not for "<s> ..."
      size_t hash = hash_value(GetPrefix());
      boost::hash_combine(ret, hash);
    }

    // suffix
    size_t inputSize = m_hypo.GetManager().GetSource().GetSize();
    boost::hash_combine(ret, m_hypo.GetCurrSourceRange().GetEndPos() < inputSize - 1);
    if (m_hypo.GetCurrSourceRange().GetEndPos() < inputSize - 1) { // not for "... </s>"
      size_t hash = m_lmRightContext->hash();
      boost::hash_combine(ret, hash);
    }

    return ret;
  }
  virtual bool operator==(const FFState& o) const {
    const LanguageModelChartState &other =
      static_cast<const LanguageModelChartState &>( o );

    // prefix
    if (m_hypo.GetCurrSourceRange().GetStartPos() > 0) { // not for "<s> ..."
      bool ret = GetPrefix() == other.GetPrefix();
      if (ret == false)
        return false;
    }

    // suffix
    size_t inputSize = m_hypo.GetManager().GetSource().GetSize();
    if (m_hypo.GetCurrSourceRange().GetEndPos() < inputSize - 1) { // not for "... </s>"
      bool ret = (*other.GetRightContext()) == (*m_lmRightContext);
      return ret;
    }
    return true;
  }

};

} // namespace