Welcome to mirror list, hosted at ThFree Co, Russian Federation.

beam_search.h « translator « src - github.com/marian-nmt/marian.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
blob: 60659a22a03f1ca7766480dc643879cad01b4d99 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
#pragma once
#include <algorithm>
#include <utility>

#include "marian.h"
#include "translator/history.h"
#include "translator/scorers.h"

#include "translator/helpers.h"
#include "translator/nth_element.h"

namespace marian {

class BeamSearch {
private:
  Ptr<Options> options_;
  std::vector<Ptr<Scorer>> scorers_;
  size_t beamSize_;
  Word trgEosId_ = (Word)-1;
  Word trgUnkId_ = (Word)-1;

  static constexpr auto INVALID_PATH_SCORE = -9999;

public:
  BeamSearch(Ptr<Options> options,
             const std::vector<Ptr<Scorer>>& scorers,
             Word trgEosId,
             Word trgUnkId = -1)
      : options_(options),
        scorers_(scorers),
        beamSize_(options_->has("beam-size")
                      ? options_->get<size_t>("beam-size")
                      : 3),
        trgEosId_(trgEosId),
        trgUnkId_(trgUnkId) {}

  // combine new expandedPathScores and previous beams into new set of beams
  Beams toHyps(const std::vector<unsigned int>& nBestKeys, // [dimBatch, beamSize] flattened -> ((batchIdx, beamHypIdx) flattened, word idx) flattened
               const std::vector<float>& nBestPathScores,  // [dimBatch, beamSize] flattened
               const size_t vocabSize,
               const Beams& beams,
               const std::vector<Ptr<ScorerState /*const*/>>& states,
               const size_t beamSize,
               const bool first,
               Ptr<data::CorpusBatch /*const*/> batch) const {
    std::vector<float> align;
    if(options_->hasAndNotEmpty("alignment"))
      align = scorers_[0]->getAlignment(); // use alignments from the first scorer, even if ensemble

    const auto dimBatch = beams.size();
    Beams newBeams(dimBatch);

    for(size_t i = 0; i < nBestKeys.size(); ++i) { // [dimBatch, beamSize] flattened
      // Keys encode batchIdx, beamHypIdx, and word index in the entire beam.
      // They can be between 0 and beamSize * vocabSize-1.
      const auto  key       = nBestKeys[i];
      const float pathScore = nBestPathScores[i]; // expanded path score for (batchIdx, beamHypIdx, word)

      // decompose key into individual indices (batchIdx, beamHypIdx, wordIdx)
      const auto wordIdx    = (Word)(key % vocabSize);
      const auto beamHypIdx =       (key / vocabSize) % (first ? 1 : beamSize);
      const auto batchIdx   =       (key / vocabSize) / (first ? 1 : beamSize);

      ABORT_IF(i / beamSize != batchIdx, "Inconsistent batchIdx value in key??");

      const auto& beam = beams[batchIdx];
      auto& newBeam = newBeams[batchIdx];

      if (newBeam.size() >= beam.size()) // @TODO: Why this condition? It does happen. Why?
        continue;
      if (pathScore <= INVALID_PATH_SCORE) // (unused slot)
        continue;

      ABORT_IF(beamHypIdx >= beam.size(), "Out of bounds beamHypIdx??");

      // Map wordIdx to word
      Word word;
      // If short list has been set, then wordIdx is an index into the short-listed word set,
      // rather than the true word index.
      auto shortlist = scorers_[0]->getShortlist();
      if (shortlist)
        word = shortlist->reverseMap(wordIdx);
      else
        word = wordIdx;

      auto hyp = New<Hypothesis>(beam[beamHypIdx], word, (IndexType)beamHypIdx, pathScore);

      // Set score breakdown for n-best lists
      if(options_->get<bool>("n-best")) {
        std::vector<float> breakDown(states.size(), 0);
        beam[beamHypIdx]->getScoreBreakdown().resize(states.size(), 0); // @TODO: Why? Can we just guard the read-out below, then make it const? Or getScoreBreakdown(j)?
        for(size_t j = 0; j < states.size(); ++j) {
          size_t flattenedLogitIndex = (beamHypIdx * dimBatch + batchIdx) * vocabSize + wordIdx; // (beam idx, batch idx, word idx); note: beam and batch are transposed, compared to 'key'
          breakDown[j] = states[j]->breakDown(flattenedLogitIndex) + beam[beamHypIdx]->getScoreBreakdown()[j];
          // @TODO: pass those 3 indices directly into breakDown (state knows the dimensions)
        }
        hyp->setScoreBreakdown(breakDown);
      }

      // Set alignments
      if(!align.empty()) {
        hyp->setAlignment(getAlignmentsForHypothesis(align, batch, (int)beamHypIdx, (int)batchIdx));
      }

      newBeam.push_back(hyp);
    }
    return newBeams;
  }

  std::vector<float> getAlignmentsForHypothesis(
      const std::vector<float> alignAll,
      Ptr<data::CorpusBatch> batch,
      int beamHypIdx,
      int beamIdx) const {
    // Let's B be the beam size, N be the number of batched sentences,
    // and L the number of words in the longest sentence in the batch.
    // The alignment vector:
    //
    // if(first)
    //   * has length of N x L if it's the first beam
    //   * stores elements in the following order:
    //     beam1 = [word1-batch1, word1-batch2, ..., word2-batch1, ...]
    // else
    //   * has length of N x L x B
    //   * stores elements in the following order:
    //     beams = [beam1, beam2, ..., beam_n]
    //
    // The mask vector is always of length N x L and has 1/0s stored like
    // in a single beam, i.e.:
    //   * [word1-batch1, word1-batch2, ..., word2-batch1, ...]
    //
    size_t batchSize = batch->size();
    size_t batchWidth = batch->width() * batchSize;
    std::vector<float> align;

    for(size_t w = 0; w < batchWidth / batchSize; ++w) {
      size_t a = ((batchWidth * beamHypIdx) + beamIdx) + (batchSize * w);
      size_t m = a % batchWidth;
      if(batch->front()->mask()[m] != 0)
        align.emplace_back(alignAll[a]);
    }

    return align;
  }

  // remove all beam entries that have reached EOS
  Beams purgeBeams(const Beams& beams) {
    Beams newBeams;
    for(auto beam : beams) {
      Beam newBeam;
      for(auto hyp : beam) {
        if(hyp->getWord() != trgEosId_) {
          newBeam.push_back(hyp);
        }
      }
      newBeams.push_back(newBeam);
    }
    return newBeams;
  }

  //**********************************************************************
  // main decoding function
  Histories search(Ptr<ExpressionGraph> graph, Ptr<data::CorpusBatch> batch) {
    ABORT_IF(batch->back()->vocab() && batch->back()->vocab()->getEosId() != trgEosId_,
        "Batch uses different EOS token than was passed to BeamSearch originally");

    const int dimBatch = (int)batch->size();

    auto getNBestList = createGetNBestListFn(beamSize_, dimBatch, graph->getDeviceId());

    for(auto scorer : scorers_) {
      scorer->clear(graph);
    }

    Histories histories(dimBatch);
    for(int i = 0; i < dimBatch; ++i) {
      size_t sentId = batch->getSentenceIds()[i];
      histories[i] = New<History>(sentId,
                                  options_->get<float>("normalize"),
                                  options_->get<float>("word-penalty"));
    }

    // start states
    std::vector<Ptr<ScorerState>> states;
    for(auto scorer : scorers_) {
      states.push_back(scorer->startState(graph, batch));
    }

    Beams beams(dimBatch, Beam(beamSize_, New<Hypothesis>())); // array [dimBatch] of array [localBeamSize] of Hypothesis
    //Beams beams(dimBatch); // array [dimBatch] of array [localBeamSize] of Hypothesis
    //for(auto& beam : beams)
    //  beam.resize(beamSize_, New<Hypothesis>());

    for(int i = 0; i < dimBatch; ++i)
      histories[i]->add(beams[i], trgEosId_);

    // the decoder updates the following state information in each output time step:
    //  - beams: array [dimBatch] of array [localBeamSize] of Hypothesis
    //     - current output time step's set of active hypotheses, aka active search space
    //  - states[.]: ScorerState
    //     - NN state; one per scorer, e.g. 2 for ensemble of 2
    // and it forms the following return value
    //  - histories: array [dimBatch] of History
    //    with History: vector [t] of array [localBeamSize] of Hypothesis
    //    with Hypothesis: (last word, aggregate score, prev Hypothesis)

    // main loop over output time steps
    for (size_t t = 0; ; t++) {
      ABORT_IF(dimBatch != beams.size(), "Lost a batch entry??");
      // determine beam size for next output time step, as max over still-active sentences
      // E.g. if all batch entries are down from beam 5 to no more than 4 surviving hyps, then
      // switch to beam of 4 for all. If all are done, then beam ends up being 0, and we are done.
      size_t localBeamSize = 0; // @TODO: is there some std::algorithm for this?
      for(auto& beam : beams)
        if(beam.size() > localBeamSize)
          localBeamSize = beam.size();

      // done if all batch entries have reached EOS on all beam entries
      if (localBeamSize == 0)
        break;

      //**********************************************************************
      // create constant containing previous path scores for current beam
      // Also create mapping of hyp indices, for reordering the decoder-state tensors.
      std::vector<IndexType> hypIndices; // [localBeamsize, 1, dimBatch, 1] (flattened) tensor index ((beamHypIdx, batchIdx), flattened) of prev hyp that a hyp originated from
      std::vector<Word> prevWords;       // [localBeamsize, 1, dimBatch, 1] (flattened) word that a hyp ended in, for advancing the decoder-model's history
      Expr prevPathScores;               // [localBeamSize, 1, dimBatch, 1], path score that a hyp ended in (last axis will broadcast into vocab size when adding expandedPathScores)
      if(t == 0) { // no scores yet
        prevPathScores = graph->constant({1, 1, 1, 1}, inits::from_value(0));
      } else {
        std::vector<float> prevScores;
        for(size_t beamHypIdx = 0; beamHypIdx < localBeamSize; ++beamHypIdx) {
          for(int batchIdx = 0; batchIdx < dimBatch; ++batchIdx) { // loop over batch entries (active sentences)
            auto& beam = beams[batchIdx];
            if(beamHypIdx < beam.size()) {
              auto hyp = beam[beamHypIdx];
              hypIndices.push_back((IndexType)(hyp->getPrevStateIndex() * dimBatch + batchIdx)); // (beamHypIdx, batchIdx), flattened, for index_select() operation
              prevWords .push_back(hyp->getWord());
              prevScores.push_back(hyp->getPathScore());
            } else {  // pad to localBeamSize (dummy hypothesis)
              hypIndices.push_back(0);
              prevWords.push_back(trgEosId_);  // (unused, but let's use a valid value)
              prevScores.push_back((float)INVALID_PATH_SCORE);
            }
          }
        }
        prevPathScores = graph->constant({(int)localBeamSize, 1, dimBatch, 1}, inits::from_vector(prevScores));
      }

      //**********************************************************************
      // compute expanded path scores with word prediction probs from all scorers
      auto expandedPathScores = prevPathScores; // will become [localBeamSize, 1, dimBatch, dimVocab]
      Expr logProbs;
      for(size_t i = 0; i < scorers_.size(); ++i) {
        // compute output probabilities for current output time step
        //  - uses hypIndices[index in beam, 1, batch index, 1] to reorder scorer state to reflect the top-N in beams[][]
        //  - adds prevWords [index in beam, 1, batch index, 1] to the scorer's target history
        //  - performs one step of the scorer
        //  - returns new NN state for use in next output time step
        //  - returns vector of prediction probabilities over output vocab via newState
        states[i] = scorers_[i]->step(graph, states[i], hypIndices, prevWords, dimBatch, (int)localBeamSize);
        logProbs = states[i]->getLogProbs(); // [localBeamSize, 1, dimBatch, dimVocab]
        // expand all hypotheses, [localBeamSize, 1, dimBatch, 1] -> [localBeamSize, 1, dimBatch, dimVocab]
        expandedPathScores = expandedPathScores + scorers_[i]->getWeight() * logProbs;
      }

      // make beams continuous
      if(dimBatch > 1 && localBeamSize > 1)
        expandedPathScores = swapAxes(expandedPathScores, 0, 2); // -> [dimBatch, 1, localBeamSize, dimVocab]
      else // (avoid copy if we can)
        expandedPathScores = reshape(expandedPathScores, {dimBatch, 1, (int)localBeamSize, expandedPathScores->shape()[-1]});

      // perform NN computation
      if(t == 0)
        graph->forward();
      else
        graph->forwardNext();

      //**********************************************************************
      // suppress specific symbols if not at right positions
      if(trgUnkId_ != -1 && options_->has("allow-unk") && !options_->get<bool>("allow-unk"))
        suppressWord(expandedPathScores, trgUnkId_);
      for(auto state : states)
        state->blacklist(expandedPathScores, batch);

      //**********************************************************************
      // perform beam search

      // find N best amongst the (localBeamSize * dimVocab) hypotheses
      std::vector<unsigned int> nBestKeys; // [dimBatch, localBeamSize] flattened -> (batchIdx, beamHypIdx, word idx) flattened
      std::vector<float> nBestPathScores;  // [dimBatch, localBeamSize] flattened
      getNBestList(/*in*/ expandedPathScores->val(),                           // [dimBatch, 1, localBeamSize, dimVocab or dimShortlist]
                   /*N=*/localBeamSize,
                   /*out*/ nBestPathScores, /*out*/ nBestKeys,
                   /*first=*/t == 0); // @TODO: Why is this passed? To know that the beam size is 1 for first step, for flattened hyp index?
      // Now, nBestPathScores contain N-best expandedPathScores for each batch and beam,
      // and nBestKeys for each their original location (batchIdx, beamHypIdx, word).

      // combine N-best sets with existing search space (beams) to updated search space
      beams = toHyps(nBestKeys, nBestPathScores,
                     /*dimTrgVoc=*/expandedPathScores->shape()[-1],
                     beams,
                     states,           // used for keeping track of per-ensemble-member path score
                     localBeamSize,    // used in the encoding of the (batchIdx, beamHypIdx, word) tuples
                     /*first=*/t == 0, // used to indicate originating beamSize of 1
                     batch);

      // remove all hyps that end in EOS
      // The position of a hyp in the beam may change.
      const auto purgedNewBeams = purgeBeams(beams);

      // add updated search space (beams) to our return value
      bool maxLengthReached = false;
      for(int i = 0; i < dimBatch; ++i) {
        // if this batch entry has surviving hyps then add them to the traceback grid
        if(!beams[i].empty()) {
          if (histories[i]->size() >= options_->get<float>("max-length-factor") * batch->front()->batchWidth())
            maxLengthReached = true;
          histories[i]->add(beams[i], trgEosId_, purgedNewBeams[i].empty() || maxLengthReached);
        }
      }
      if (maxLengthReached) // early exit if max length limit was reached
        break;

      // this is the search space for the next output time step
      beams = purgedNewBeams;
    } // end of main loop over output time steps

    return histories; // [dimBatch][t][N best hyps]
  }
};
}  // namespace marian