// mert/ForestRescore.cpp (moses-smt/mosesdecoder)
/***********************************************************************
Moses - factored phrase-based language decoder
Copyright (C) 2014- University of Edinburgh

This library is free software; you can redistribute it and/or
modify it under the terms of the GNU Lesser General Public
License as published by the Free Software Foundation; either
version 2.1 of the License, or (at your option) any later version.

This library is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
Lesser General Public License for more details.

You should have received a copy of the GNU Lesser General Public
License along with this library; if not, write to the Free Software
Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA
***********************************************************************/

#include <algorithm>
#include <cassert>
#include <cmath>
#include <limits>
#include <list>

#include <boost/unordered_set.hpp>

#include "util/file_piece.hh"
#include "util/tokenize_piece.hh"

#include "BleuScorer.h"
#include "ForestRescore.h"

using namespace std;

namespace MosesTuning {

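// Debugging helper: print a WordVec as a bracketed, space-separated token list.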
std::ostream& operator<<(std::ostream& out, const WordVec& wordVec) {
  out << "[";
  for (size_t i = 0; i < wordVec.size(); ++i) {
    out << wordVec[i]->first;
    if (i + 1 < wordVec.size()) out << " ";
  }
  out << "]";
  return out;
}


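// Load reference translations. Line i of each file is treated as a
// reference for sentence i, so passing several files yields several
// references per sentence.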
void ReferenceSet::Load(const vector<string>& files, Vocab& vocab) {
  for (size_t i = 0; i < files.size(); ++i) {
    util::FilePiece fh(files[i].c_str());
    size_t sentenceId = 0;
    while(true) {
      StringPiece line;
      try {
        line = fh.ReadLine();
      } catch (util::EndOfFileException &e) {
        break;
      }
      AddLine(sentenceId, line, vocab);
      ++sentenceId;
    }
  }
}

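// Tokenize one reference line and count its ngrams up to kBleuNgramOrder,
// maintaining a sliding list of "open" ngrams that each new token extends.
// Per-sentence totals keep both a clipped count (the maximum in any single
// reference) and an unclipped count (summed over references); the stored
// reference length is the minimum over the references seen so far.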
void ReferenceSet::AddLine(size_t sentenceId, const StringPiece& line, Vocab& vocab) {
  //cerr << line << endl;
  NgramCounter ngramCounts;
  list<WordVec> openNgrams;
  size_t length = 0;
  //tokenize & count
  for (util::TokenIter<util::SingleCharacter, true> j(line, util::SingleCharacter(' ')); j; ++j) {
    const Vocab::Entry* nextTok = &(vocab.FindOrAdd(*j));
    ++length;
    openNgrams.push_front(WordVec());
    for (list<WordVec>::iterator k = openNgrams.begin(); k != openNgrams.end(); ++k) {
      k->push_back(nextTok);
      ++ngramCounts[*k];
    }
    if (openNgrams.size() >= kBleuNgramOrder) openNgrams.pop_back();
  }

  //merge into overall ngram map
  //resize up front so that even an empty reference line gets an entry;
  //otherwise NgramMatches() would throw for that sentence
  if (ngramCounts_.size() <= sentenceId) ngramCounts_.resize(sentenceId+1);
  for (NgramCounter::const_iterator ni = ngramCounts.begin();
       ni != ngramCounts.end(); ++ni) {
    size_t count = ni->second;
    //cerr << *ni << " " << count << endl;
    NgramMap::iterator totalsIter = ngramCounts_[sentenceId].find(ni->first);
    if (totalsIter == ngramCounts_[sentenceId].end()) {
      ngramCounts_[sentenceId][ni->first] = pair<size_t,size_t>(count,count);
    } else {
      ngramCounts_[sentenceId][ni->first].first = max(count, ngramCounts_[sentenceId][ni->first].first); //clip
      ngramCounts_[sentenceId][ni->first].second += count; //no clip
    }
  }
  //length
  if (lengths_.size() <= sentenceId) lengths_.resize(sentenceId+1);
  //TODO - length strategy - this is MIN
  if (!lengths_[sentenceId]) {
    lengths_[sentenceId] = length;
  } else {
    lengths_[sentenceId] = min(length,lengths_[sentenceId]);
  }
  //cerr << endl;

}
  
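// How often "ngram" occurs in the references for this sentence: clipped
// gives the maximum count in any single reference, unclipped the total
// over all references.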
size_t ReferenceSet::NgramMatches(size_t sentenceId, const WordVec& ngram, bool clip) const  {
  const NgramMap& ngramCounts = ngramCounts_.at(sentenceId);
  NgramMap::const_iterator ngi = ngramCounts.find(ngram);
  if (ngi == ngramCounts.end()) return 0;
  return clip ? ngi->second.first : ngi->second.second;
}

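// A VertexState caches the best derivation at a vertex: its boundary words
// (left/right contexts of up to kBleuNgramOrder-1 words), bleu stats and
// target length. The stats vector uses the full layout (two counts per
// ngram order plus a reference-length slot), since Score() sums a child's
// stats over that whole length.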
VertexState::VertexState(): bleuStats(kBleuNgramOrder*2+1), targetLength(0) {}

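// Fold ngram counts into the running bleu stats: for each order n,
// stats[2(n-1)+1] accumulates the ngrams hypothesised and stats[2(n-1)]
// the matches against the (unclipped) reference counts.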
void HgBleuScorer::UpdateMatches(const NgramCounter& counts, vector<FeatureStatsType>& bleuStats) const {
  for (NgramCounter::const_iterator ngi = counts.begin(); ngi != counts.end(); ++ngi) {
    //cerr << "Checking: " << *ngi << " matches " << references_.NgramMatches(sentenceId_,*ngi,false) <<  endl;
    size_t order = ngi->first.size();
    size_t count = ngi->second;
    bleuStats[(order-1)*2 + 1] += count;
    bleuStats[(order-1)*2] += min(count, references_.NgramMatches(sentenceId_,ngi->first,false));
  }
}

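// Target length of the derivation rooted at this edge: its own terminals
// plus the cached lengths of its child vertices.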
size_t HgBleuScorer::GetTargetLength(const Edge& edge) const {
  size_t targetLength = 0;
  for (size_t i = 0; i < edge.Words().size(); ++i) {
    const Vocab::Entry* word = edge.Words()[i];
    if (word) ++targetLength;
  }
  for (size_t i = 0; i < edge.Children().size(); ++i) {
    const VertexState& state = vertexStates_[edge.Children()[i]];
    targetLength += state.targetLength;
  }
  return targetLength;
}

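// Score an incoming edge against the references. The edge's words are
// walked left to right; a NULL word marks a child vertex, whose cached left
// context (and, when that context is full, its right context) is scanned in
// its place. Ngrams wholly inside a child were already counted in the
// child's state, so only ngrams crossing into this edge are counted afresh,
// and the children's bleuStats are summed in afterwards. The last stat slot
// is the effective reference length, scaled by the source coverage of this
// span.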
FeatureStatsType HgBleuScorer::Score(const Edge& edge, const Vertex& head, vector<FeatureStatsType>& bleuStats) {
  NgramCounter ngramCounts;
  size_t childId = 0;
  size_t wordId = 0;
  size_t contextId = 0; //position within left or right context
  const VertexState* vertexState = NULL;
  bool inLeftContext = false;
  bool inRightContext = false;
  list<WordVec> openNgrams;
  const Vocab::Entry* currentWord = NULL;
  while (wordId < edge.Words().size()) { 
    currentWord = edge.Words()[wordId];
    if (currentWord != NULL) {
      ++wordId;
    } else {
      if (!inLeftContext && !inRightContext) {
        //entering a vertex
        assert(!vertexState);
        vertexState = &(vertexStates_[edge.Children()[childId]]);
        ++childId;
        if (vertexState->leftContext.size()) {
          inLeftContext = true;
          contextId = 0;
          currentWord = vertexState->leftContext[contextId];
        } else {
          //empty context
          vertexState = NULL;
          ++wordId;
          continue;
        }
      } else {
        //already in a vertex
        ++contextId;
        if (inLeftContext && contextId < vertexState->leftContext.size()) {
          //still in left context
          currentWord = vertexState->leftContext[contextId];
        } else if (inLeftContext) {
          //at end of left context
          if (vertexState->leftContext.size() == kBleuNgramOrder-1) {
            //full size context, jump to right state
            openNgrams.clear();
            inLeftContext = false;
            inRightContext = true;
            contextId = 0;
            currentWord = vertexState->rightContext[contextId];
          } else {
            //short context: the child's whole yield fitted inside the left
            //context, so the right context would just repeat it
            inLeftContext = false;
            vertexState = NULL;
            ++wordId;
            continue;
          }
        } else {
          //in right context
          if (contextId < vertexState->rightContext.size()) {
            currentWord = vertexState->rightContext[contextId];
          } else {
            //leaving vertex
            inRightContext = false;
            vertexState = NULL;
            ++wordId;
            continue;
          }
        }
      }
    }
    assert(currentWord);
    if (graph_.IsBoundary(currentWord)) continue;
    openNgrams.push_front(WordVec());
    openNgrams.front().reserve(kBleuNgramOrder);
    for (list<WordVec>::iterator k = openNgrams.begin(); k != openNgrams.end(); ++k) {
      k->push_back(currentWord);
      //Count an ngram only if it is not wholly contained in a child vertex;
      //child-internal ngrams were counted when the child was scored
      if (!vertexState || (inLeftContext && k->size() > contextId+1)) ++ngramCounts[*k];
    }
    if (openNgrams.size() >= kBleuNgramOrder) openNgrams.pop_back();
  }
  
  //Collect matches
  //This edge
  //cerr << "edge ngrams" << endl;
  UpdateMatches(ngramCounts, bleuStats);

  //Child vertexes
  for (size_t i = 0; i < edge.Children().size(); ++i) {
    //cerr << "vertex ngrams " << edge.Children()[i] << endl;
    for (size_t j = 0; j < bleuStats.size(); ++j) {
      bleuStats[j] += vertexStates_[edge.Children()[i]].bleuStats[j];
    }
  }
  

  FeatureStatsType sourceLength = head.SourceCovered();
  size_t referenceLength = references_.Length(sentenceId_);
  FeatureStatsType effectiveReferenceLength = 
    sourceLength / totalSourceLength_ * referenceLength;

  bleuStats[bleuStats.size()-1] = effectiveReferenceLength;
  //backgroundBleu_[backgroundBleu_.size()-1] = 
  //  backgroundRefLength_ * sourceLength / totalSourceLength_;
  FeatureStatsType bleu = sentenceLevelBackgroundBleu(bleuStats, backgroundBleu_);

  return bleu;
}

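// Record the winning edge at this vertex: assemble the left and right
// boundary contexts (up to kBleuNgramOrder-1 words each) from the edge's
// own words and its children's cached contexts, and store the winner's
// bleu stats and target length for use when scoring parent edges.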
void HgBleuScorer::UpdateState(const Edge& winnerEdge, size_t vertexId, const vector<FeatureStatsType>& bleuStats) {
  //TODO: Maybe more efficient to absorb into the Score() method
  VertexState& vertexState = vertexStates_[vertexId];
  //cerr << "Updating state for " << vertexId << endl;
  
  //leftContext
  int wi = 0;
  const VertexState* childState = NULL;
  int contexti = 0; //index within child context
  int childi = 0;
  while (vertexState.leftContext.size() < (kBleuNgramOrder-1)) {
    if ((size_t)wi >= winnerEdge.Words().size()) break;
    const Vocab::Entry* word = winnerEdge.Words()[wi];
    if (word != NULL) {
      vertexState.leftContext.push_back(word);
      ++wi;
    } else {
      if (childState == NULL) {
        //start of child state
        childState = &(vertexStates_[winnerEdge.Children()[childi++]]);
        contexti = 0;
      } 
      if ((size_t)contexti < childState->leftContext.size()) {
        vertexState.leftContext.push_back(childState->leftContext[contexti++]); 
      } else {
        //end of child context
        childState = NULL;
        ++wi;
      }
    }
  }

  //rightContext
  wi = winnerEdge.Words().size() - 1;
  childState = NULL;
  childi = winnerEdge.Children().size() - 1;
  while (vertexState.rightContext.size() < (kBleuNgramOrder-1)) {
    if (wi < 0) break;
    const Vocab::Entry* word = winnerEdge.Words()[wi];
    if (word != NULL) {
      vertexState.rightContext.push_back(word);
      --wi;
    } else {
      if (childState == NULL) {
        //start (ie rhs) of child state
        childState = &(vertexStates_[winnerEdge.Children()[childi--]]);
        contexti = childState->rightContext.size()-1;
      }
      if (contexti >= 0) {
        vertexState.rightContext.push_back(childState->rightContext[contexti--]);
      } else {
        //end (ie lhs) of child context
        childState = NULL;
        --wi;
      }
    }
  }
  reverse(vertexState.rightContext.begin(), vertexState.rightContext.end());

  //length + counts
  vertexState.targetLength = GetTargetLength(winnerEdge);
  vertexState.bleuStats = bleuStats;
}


typedef pair<const Edge*,FeatureStatsType> BackPointer;


/**
 * Recurse through back pointers
 **/
static void GetBestHypothesis(size_t vertexId, const Graph& graph, const vector<BackPointer>& bps,
     HgHypothesis* bestHypo) {
  //cerr << "Expanding " << vertexId << endl;
  //UTIL_THROW_IF(bps[vertexId].second == kMinScore+1, HypergraphException, "Landed at vertex " << vertexId << " which is a dead end");
  if (!bps[vertexId].first) return;
  const Edge* prevEdge = bps[vertexId].first;
  bestHypo->featureVector += *(prevEdge->Features().get());
  size_t childId = 0;
  for (size_t i = 0; i < prevEdge->Words().size(); ++i) {
    if (prevEdge->Words()[i] != NULL) {
      bestHypo->text.push_back(prevEdge->Words()[i]);
    } else {
      size_t childVertexId = prevEdge->Children()[childId++];
      HgHypothesis childHypo;
      GetBestHypothesis(childVertexId,graph,bps,&childHypo);
      bestHypo->text.insert(bestHypo->text.end(), childHypo.text.begin(), childHypo.text.end());
      bestHypo->featureVector += childHypo.featureVector;
    }
  }
}

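// Viterbi search over the hypergraph. Vertices must arrive in topological
// order; each incoming edge is scored as its model score plus bleuWeight
// times a sentence-level, background-smoothed BLEU, and a back pointer
// records the best edge per vertex. The best hypothesis is read off the
// final vertex, and its BLEU stats are then recomputed with clipping,
// since the search itself uses unclipped counts.
//
// Minimal caller sketch (hypothetical names; the real setup lives in the
// hypergraph MERT tooling):
//   ReferenceSet refs;
//   Vocab vocab;
//   refs.Load(refFiles, vocab);
//   HgHypothesis best;
//   Viterbi(graph, weights, bleuWeight, refs, sentenceId, backgroundBleu, &best);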
void Viterbi(const Graph& graph, const SparseVector& weights, float bleuWeight, const ReferenceSet& references, size_t sentenceId, const std::vector<FeatureStatsType>& backgroundBleu, HgHypothesis* bestHypo)
{
  BackPointer init(NULL,kMinScore);
  vector<BackPointer> backPointers(graph.VertexSize(),init);
  HgBleuScorer bleuScorer(references, graph, sentenceId, backgroundBleu);
  vector<FeatureStatsType> winnerStats(kBleuNgramOrder*2+1);
  for (size_t vi = 0; vi < graph.VertexSize(); ++vi) {
    //cerr << "vertex id " << vi <<  endl;
    FeatureStatsType winnerScore = kMinScore;
    const Vertex& vertex = graph.GetVertex(vi);
    const vector<const Edge*>& incoming = vertex.GetIncoming();
    if (!incoming.size()) {
      //UTIL_THROW(HypergraphException, "Vertex " << vi << " has no incoming edges");
      //If no incoming edges, vertex is a dead end
      backPointers[vi].first = NULL;
      backPointers[vi].second = kMinScore/2;  
    } else {
      //cerr << "\nVertex: " << vi << endl;
      for (size_t ei = 0; ei < incoming.size(); ++ei) {
        //cerr << "edge id " << ei << endl;
        FeatureStatsType incomingScore = incoming[ei]->GetScore(weights);
        for (size_t i = 0; i < incoming[ei]->Children().size(); ++i) {
          size_t childId = incoming[ei]->Children()[i];
          UTIL_THROW_IF(backPointers[childId].second == kMinScore,
            HypergraphException, "Graph was not topologically sorted. curr=" << vi << " prev=" << childId);
          incomingScore += backPointers[childId].second;
        }
        vector<FeatureStatsType> bleuStats(kBleuNgramOrder*2+1);
       // cerr << "Score: " << incomingScore << " Bleu: ";
       // if (incomingScore > nonbleuscore) {nonbleuscore = incomingScore; nonbleuid = ei;}
        FeatureStatsType totalScore = incomingScore;
        if (bleuWeight) { 
          FeatureStatsType bleuScore = bleuScorer.Score(*(incoming[ei]), vertex, bleuStats);
          UTIL_THROW_IF(isnan(bleuScore), util::Exception, "Bleu score undefined, smoothing problem?");
          totalScore += bleuWeight * bleuScore;
        //  cerr << bleuScore << " Total: " << incomingScore << endl << endl;
          //cerr << "is " << incomingScore << " bs " << bleuScore << endl;
        }
        if (totalScore >= winnerScore) {
          //We only store the feature score (not the bleu score) with the vertex,
          //since the bleu score is always cumulative, ie from counts for the whole span.
          winnerScore = totalScore;
          backPointers[vi].first = incoming[ei];
          backPointers[vi].second = incomingScore;
          winnerStats = bleuStats;
        }
      }
      //update with winner
      //if (bleuWeight) {
      //TODO: Not sure if we need this when computing max-model solution
      bleuScorer.UpdateState(*(backPointers[vi].first), vi, winnerStats);

    }
  }

  //expand back pointers
  GetBestHypothesis(graph.VertexSize()-1, graph, backPointers, bestHypo);

  //bleu stats and fv

  //Need the actual (clipped) stats
  //TODO: This repeats code in bleu scorer - factor out
  bestHypo->bleuStats.resize(kBleuNgramOrder*2+1);
  NgramCounter counts;
  list<WordVec> openNgrams;
  for (size_t i = 0; i < bestHypo->text.size(); ++i) {
    const Vocab::Entry* entry = bestHypo->text[i];
    if (graph.IsBoundary(entry)) continue;
    openNgrams.push_front(WordVec());
    for (list<WordVec>::iterator k = openNgrams.begin(); k != openNgrams.end(); ++k) {
      k->push_back(entry);
      ++counts[*k];
    }
    if (openNgrams.size() >= kBleuNgramOrder) openNgrams.pop_back();
  }
  for (NgramCounter::const_iterator ngi = counts.begin(); ngi != counts.end(); ++ngi) {
    size_t order = ngi->first.size();
    size_t count = ngi->second;
    bestHypo->bleuStats[(order-1)*2 + 1] += count;
    bestHypo->bleuStats[(order-1)*2] += min(count, references.NgramMatches(sentenceId,ngi->first,true));
  }
  bestHypo->bleuStats[kBleuNgramOrder*2] = references.Length(sentenceId);
}


} // namespace MosesTuning