Welcome to mirror list, hosted at ThFree Co, Russian Federation.

Model1Feature.cpp « FF « moses - github.com/moses-smt/mosesdecoder.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
blob: 38883c12e8599d66e22dde65a1cd0d87dfb2b74b (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
#include "Model1Feature.h"
#include "moses/StaticData.h"
#include "moses/InputFileStream.h"
#include "moses/ScoreComponentCollection.h"
#include "moses/FactorCollection.h"


using namespace std;

namespace Moses
{

// Surface form under which the GIZA empty (NULL) word is registered in the
// FactorCollection; Model1Vocabulary binds it to id 0.
const std::string Model1Vocabulary::GIZANULL("GIZANULL");

// Creates a vocabulary that already contains the empty word: the GIZANULL
// factor is registered (or looked up) in the global FactorCollection and
// bound to id 0.
Model1Vocabulary::Model1Vocabulary()
{
  m_NULL = FactorCollection::Instance().AddFactor(GIZANULL, false);
  Store(m_NULL, 0);
}

// Binds `word` to the explicit vocabulary id `id`, growing the id -> word
// table as needed (slots for ids that were never stored remain NULL).
//
// Returns true on success. Returns false, leaving the vocabulary unchanged,
// when either
//   - `word` already has an id, or
//   - `id` is already bound to a word.
// The second check keeps m_lookup (word -> id) and m_vocab (id -> word)
// consistent; without it a reused id would silently replace the id -> word
// entry while the previous word still mapped to that id.
bool Model1Vocabulary::Store(const Factor* word, const unsigned id)
{
  if ( m_lookup.find( word ) != m_lookup.end() ) {
    return false;
  }
  if ( (id < m_vocab.size()) && (m_vocab[id] != NULL) ) {
    return false;
  }
  m_lookup[ word ] = id;
  if ( m_vocab.size() <= id ) {
    m_vocab.resize(id+1);  // new slots are value-initialized to NULL
  }
  m_vocab[id] = word;
  return true;
}

unsigned Model1Vocabulary::StoreIfNew(const Factor* word) 
{
  boost::unordered_map<const Factor*, unsigned>::iterator iter = m_lookup.find( word );

  if ( iter != m_lookup.end() ) {
    return iter->second;
  }

  unsigned id = m_vocab.size();
  m_vocab.push_back( word );
  m_lookup[ word ] = id;
  return id;
}

unsigned Model1Vocabulary::GetWordID(const Factor* word) const 
{
  boost::unordered_map<const Factor*, unsigned>::const_iterator iter = m_lookup.find( word );
  if ( iter == m_lookup.end() ) {
    return INVALID_ID;
  }
  return iter->second;
}

// Looks up the word bound to `id`. Yields NULL for ids beyond the table
// and for ids inside the table that were never stored.
const Factor* Model1Vocabulary::GetWord(unsigned id) const
{
  return ( id < m_vocab.size() ) ? m_vocab[ id ] : NULL;
}

// Reads a (M)GIZA vocabulary file with lines of the form "id word count".
// Each word is registered in the global FactorCollection and stored under
// its id.
//
// A first line of the form "1 UNK ..." (written by MGIZA) is skipped; the
// same content on any later line is stored like every other entry.
// Throws (UTIL_THROW_IF2) on a line without exactly three tokens or when
// Store() rejects an entry.
void Model1Vocabulary::Load(const std::string& fileName)
{
  InputFileStream inFile(fileName);
  FactorCollection &factorCollection = FactorCollection::Instance();

  std::string line;
  unsigned lineNo = 0;
  while ( getline(inFile, line) )
  {
    ++lineNo;
    const std::vector<std::string> tokens = Tokenize(line);
    UTIL_THROW_IF2(tokens.size()!=3, "Line " << lineNo << " in " << fileName << " has wrong number of tokens.");
    const unsigned id = Scan<unsigned>(tokens[0]);

    const bool isMgizaUnkHeader = (lineNo == 1) && (id == 1) && (tokens[1] == "UNK");
    if ( isMgizaUnkHeader ) {
      continue;
    }

    // TODO: can we assume that the vocabulary is known and filter the model on loading?
    const Factor* factor = factorCollection.AddFactor(tokens[1],false);
    const bool stored = Store(factor, id);
    UTIL_THROW_IF2(!stored, "Line " << lineNo << " in " << fileName << " overwrites existing vocabulary entry.");
  }
  inFile.Close();
}


// Reads a Model 1 lexical table with lines "idS idT prob", mapping the
// numeric ids to words through `vcbS` (source side) and `vcbT` (target side).
// The entry is stored as p(wordT | wordS).
//
// Throws (UTIL_THROW_IF2) on a line without exactly three tokens, or when
// either id is unknown to its vocabulary.
void Model1LexicalTable::Load(const std::string &fileName, const Model1Vocabulary& vcbS, const Model1Vocabulary& vcbT)
{
  InputFileStream inFile(fileName);
  std::string line;

  unsigned lineNo = 0;
  while ( getline(inFile, line) )
  {
    ++lineNo;
    const std::vector<std::string> tokens = Tokenize(line);
    UTIL_THROW_IF2(tokens.size()!=3, "Line " << lineNo << " in " << fileName << " has wrong number of tokens.");

    const Factor* wordS = vcbS.GetWord( Scan<unsigned>(tokens[0]) );
    const Factor* wordT = vcbT.GetWord( Scan<unsigned>(tokens[1]) );
    // TODO: can we assume that the vocabulary is known and filter the model on loading? Then remove this check.
    UTIL_THROW_IF2((wordS == NULL) || (wordT == NULL), "Line " << lineNo << " in " << fileName << " has unknown vocabulary.");

    m_ltable[ wordS ][ wordT ] = Scan<float>(tokens[2]);
  }
  inFile.Close();
}

// p( wordT | wordS )
float Model1LexicalTable::GetProbability(const Factor* wordS, const Factor* wordT) const
{
  float prob = m_floor;
 
  boost::unordered_map< const Factor*, boost::unordered_map< const Factor*, float > >::const_iterator iter1 = m_ltable.find( wordS ); 

  if ( iter1 != m_ltable.end() ) {
    boost::unordered_map< const Factor*, float >::const_iterator iter2 = iter1->second.find( wordT );
    if ( iter2 != iter1->second.end() ) {
      prob = iter2->second;
      if ( prob < m_floor ) {
        prob = m_floor;
      }
    }
  }
  return prob;
}


// Constructs the feature from its configuration `line`.
// The base class is given 1 as its first argument (presumably the number of
// dense scores; this feature adds one value per phrase). ReadParameters() is
// assumed to route each key=value option through SetParameter().
// No files are read here; the model and vocabularies are loaded in Load().
Model1Feature::Model1Feature(const std::string &line)
  : StatelessFeatureFunction(1, line)
{
  VERBOSE(1, "Initializing feature " << GetScoreProducerDescription() << " ...");
  ReadParameters();
  VERBOSE(1, " Done.");
}

// Handles the feature-specific options:
//   path             -> Model 1 lexical table file
//   sourceVocabulary -> source-side vocabulary file
//   targetVocabulary -> target-side vocabulary file
// Any other key is delegated to StatelessFeatureFunction::SetParameter.
void Model1Feature::SetParameter(const std::string& key, const std::string& value)
{
  if (key == "path") {
    m_fileNameModel1 = value;
    return;
  }
  if (key == "sourceVocabulary") {
    m_fileNameVcbS = value;
    return;
  }
  if (key == "targetVocabulary") {
    m_fileNameVcbT = value;
    return;
  }
  StatelessFeatureFunction::SetParameter(key, value);
}

// Loads the source and target vocabularies and the Model 1 lexical table
// named by the 'sourceVocabulary', 'targetVocabulary' and 'path' options.
// It then resolves the factor used for the empty (GIZANULL) word.
//
// The vocabularies are only needed to translate the table's numeric ids
// into factors, so they are local to this function and discarded afterwards.
//
// Fails via UTIL_THROW_IF2 when a required file option was not given, or
// when the empty-word factor does not exist.
void Model1Feature::Load()
{
  // Fail early with a message naming the missing option, instead of handing
  // an empty file name to InputFileStream.
  UTIL_THROW_IF2(m_fileNameVcbS.empty(), GetScoreProducerDescription()
                 << ": Parameter 'sourceVocabulary' is required.");
  UTIL_THROW_IF2(m_fileNameVcbT.empty(), GetScoreProducerDescription()
                 << ": Parameter 'targetVocabulary' is required.");
  UTIL_THROW_IF2(m_fileNameModel1.empty(), GetScoreProducerDescription()
                 << ": Parameter 'path' is required.");

  FEATUREVERBOSE(2, GetScoreProducerDescription() << ": Loading source vocabulary from file " << m_fileNameVcbS << " ...");
  Model1Vocabulary vcbS;
  vcbS.Load(m_fileNameVcbS);
  FEATUREVERBOSE2(2, " Done." << std::endl);

  FEATUREVERBOSE(2, GetScoreProducerDescription() << ": Loading target vocabulary from file " << m_fileNameVcbT << " ...");
  Model1Vocabulary vcbT;
  vcbT.Load(m_fileNameVcbT);
  FEATUREVERBOSE2(2, " Done." << std::endl);

  FEATUREVERBOSE(2, GetScoreProducerDescription() << ": Loading model 1 lexical translation table from file " << m_fileNameModel1 << " ...");
  m_model1.Load(m_fileNameModel1,vcbS,vcbT);
  FEATUREVERBOSE2(2, " Done." << std::endl);

  // The Model1Vocabulary constructors above registered GIZANULL, so this
  // lookup is expected to succeed; the check guards against that changing.
  FactorCollection &factorCollection = FactorCollection::Instance();
  m_emptyWord = factorCollection.GetFactor(Model1Vocabulary::GIZANULL,false);
  UTIL_THROW_IF2(m_emptyWord==NULL, GetScoreProducerDescription()
                 << ": Factor for GIZA empty word does not exist.");
}

void Model1Feature::EvaluateWithSourceContext(const InputType &input
                                 , const InputPath &inputPath
                                 , const TargetPhrase &targetPhrase
                                 , const StackVec *stackVec
                                 , ScoreComponentCollection &scoreBreakdown
                                 , ScoreComponentCollection *estimatedFutureScore) const
{
  const Sentence& sentence = static_cast<const Sentence&>(input);
  float score = 0.0;
  float norm = TransformScore(1+sentence.GetSize());

  for (size_t posT=0; posT<targetPhrase.GetSize(); ++posT) 
  {
    const Word &wordT = targetPhrase.GetWord(posT);
    if ( !wordT.IsNonTerminal() ) 
    {
      float thisWordProb = m_model1.GetProbability(m_emptyWord,wordT[0]); // probability conditioned on empty word

      // cache lookup
      bool foundInCache = false;
      {
        #ifdef WITH_THREADS
        boost::shared_lock<boost::shared_mutex> read_lock(m_accessLock);
        #endif
        boost::unordered_map<const InputType*, boost::unordered_map<const Factor*, float> >::const_iterator sentenceCache = m_cache.find(&input);
        if (sentenceCache != m_cache.end())
        {
          boost::unordered_map<const Factor*, float>::const_iterator cacheHit = sentenceCache->second.find(wordT[0]);
          if (cacheHit != sentenceCache->second.end())
          {
            foundInCache = true;
            score += cacheHit->second;
            FEATUREVERBOSE(3, "Cached score( " << wordT << " ) = " << cacheHit->second << std::endl);
          }
        }
      }

      if (!foundInCache)
      {
        for (size_t posS=1; posS<sentence.GetSize()-1; ++posS) // ignore <s> and </s>
        {
          const Word &wordS = sentence.GetWord(posS);
          float modelProb = m_model1.GetProbability(wordS[0],wordT[0]);
          FEATUREVERBOSE(4, "p( " << wordT << " | " << wordS << " ) = " << modelProb << std::endl);
          thisWordProb += modelProb;
        }
        float thisWordScore = TransformScore(thisWordProb) - norm;
        FEATUREVERBOSE(3, "score( " << wordT << " ) = " << thisWordScore << std::endl);
        {
          #ifdef WITH_THREADS 
          // need to update cache; write lock
          boost::unique_lock<boost::shared_mutex> lock(m_accessLock);
          #endif
          m_cache[&input][wordT[0]] = thisWordScore;
        }
        score += thisWordScore;
      }
    }
  } 

  scoreBreakdown.PlusEquals(this, score);
}
  
// Discards the per-word scores cached for `source` once the sentence has
// been decoded. Erasing a sentence that has no cache entry is a no-op.
void Model1Feature::CleanUpAfterSentenceProcessing(const InputType& source)
{
#ifdef WITH_THREADS
  // the cache is modified below: take the lock exclusively
  boost::unique_lock<boost::shared_mutex> lock(m_accessLock);
#endif
  m_cache.erase(&source);
}

}