Welcome to mirror list, hosted at ThFree Co, Russian Federation.

bert.h « models « src - github.com/marian-nmt/marian.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
blob: 514274572c8637f7c13bcf91ace4466ec26d653b (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
#pragma once

#include "data/corpus_base.h"
#include "models/encoder_classifier.h"
#include "models/transformer.h"   // @BUGBUG: transformer.h is large and was meant to be compiled separately
#include "data/rng_engine.h"

namespace marian {

/**
 * This file contains nearly all BERT-related code and adds BERT functionality
 * on top of existing classes like EncoderTransformer and Classifier.
 */

namespace data {

/**
 * BERT-specific mini-batch that computes masking for Masked LM training.
 * Expects symbols [MASK], [SEP], [CLS] to be present in vocabularies unless
 * other symbols are specified in the config.
 *
 * This takes a normal CorpusBatch and extends it with additional data. Luckily
 * all the BERT-functionality can be inferred from a CorpusBatch alone.
 */
class BertBatch : public CorpusBatch {
private:
  // Flat positions (into the word array of the first sub-batch) that were chosen for masking
  std::vector<IndexType> maskedPositions_;
  // Original words found at maskedPositions_ before they were replaced
  Words maskedWords_;
  // One sentence (type) index per word of the first sub-batch, same flat layout as the word array
  std::vector<IndexType> sentenceIndices_;

  std::string maskSymbol_;
  std::string sepSymbol_;
  std::string clsSymbol_;

  // Selects a random word id from the vocabulary, i.e. from the closed range [0, vocab.size() - 1].
  // Only set up by the masking constructor.
  std::unique_ptr<std::uniform_int_distribution<WordIndex>> randomWord_;

  // Selects a random real number in [0, 1), used to pick the masking strategy.
  // Only set up by the masking constructor.
  std::unique_ptr<std::uniform_real_distribution<float>> randomPercent_;

  // Word ids of words that should not be masked, e.g. separators, padding
  std::unordered_set<Word> dontMask_;

  // Masking function, i.e. replaces a chosen word with either
  // a [MASK] symbol, itself or a random word.
  //   word   - the original word at the chosen position
  //   mask   - the word id of the mask symbol
  //   engine - random engine used for all draws
  // Returns the word that should be written at the chosen position.
  Word maskOut(Word word, Word mask, std::mt19937& engine) {
    // @TODO: turn those threshold into parameters, adjustable from command line
    const float r = (*randomPercent_)(engine);
    if(r < 0.1f) // for 10% of cases return same word
      return word;

    if(r < 0.2f) { // for 10% return random word
      const Word randWord = Word::fromWordIndex((*randomWord_)(engine));
      // some words, e.g. [CLS] or </s>, may not be used as random words;
      // for those, return the mask symbol instead
      return dontMask_.count(randWord) > 0 ? mask : randWord;
    }

    return mask; // for 80% of words apply mask symbol
  }

public:

  // Takes a corpus batch, random engine (for deterministic behavior) and the masking percentage.
  // Also sets special vocabulary items given on command line.
  // Aborts if any of the three special symbols is not present in the vocabulary of the first sub-batch.
  // maskFraction is interpreted as a fraction of the maskable words and is clamped to [0, 1].
  BertBatch(Ptr<CorpusBatch> batch,
            std::mt19937& engine,
            float maskFraction,
            const std::string& maskSymbol,
            const std::string& sepSymbol,
            const std::string& clsSymbol,
            int dimTypeVocab)
    : CorpusBatch(*batch),
      maskSymbol_(maskSymbol), sepSymbol_(sepSymbol), clsSymbol_(clsSymbol) {

    // BERT expects a textual first stream and a second stream with class labels
    auto subBatch = subBatches_.front();
    const auto& vocab = *subBatch->vocab();

    auto& words = subBatch->data();

    // Get word id of special symbols
    Word maskId  = vocab[maskSymbol_];
    Word clsId   = vocab[clsSymbol_];
    Word sepId   = vocab[sepSymbol_];

    ABORT_IF(maskId == vocab.getUnkId(),
             "BERT masking symbol {} not found in vocabulary", maskSymbol_);

    ABORT_IF(sepId == vocab.getUnkId(),
             "BERT separator symbol {} not found in vocabulary", sepSymbol_);

    ABORT_IF(clsId == vocab.getUnkId(),
             "BERT class symbol {} not found in vocabulary", clsSymbol_);

    // Initialize to sample random vocab id. std::uniform_int_distribution samples from a
    // CLOSED interval, hence the upper bound has to be the last valid id, vocab.size() - 1.
    randomWord_.reset(new std::uniform_int_distribution<WordIndex>(0, (WordIndex)vocab.size() - 1));

    // Initialize to sample a random real number in [0, 1)
    randomPercent_.reset(new std::uniform_real_distribution<float>(0.f, 1.f));

    dontMask_.insert(clsId); // don't mask class token
    dontMask_.insert(sepId); // don't mask separator token
    dontMask_.insert(vocab.getEosId()); // don't mask </s>
    // it's ok to mask <unk>

    std::vector<IndexType> selected;
    selected.reserve(words.size());
    for(size_t i = 0; i < words.size(); ++i) // collect words among which we will mask
      if(dontMask_.count(words[i]) == 0)     // do not add indices of special words
        selected.push_back((IndexType)i);
    std::shuffle(selected.begin(), selected.end(), engine); // randomize positions

    // select first x percent from shuffled indices; never more than there are candidates
    // (resize() would otherwise append bogus zero indices) and never a negative amount
    size_t numMasked = 0;
    if(maskFraction > 0.f) {
      numMasked = (size_t)std::ceil(selected.size() * maskFraction);
      if(numMasked > selected.size())
        numMasked = selected.size();
    }
    selected.resize(numMasked);

    maskedPositions_.reserve(numMasked);
    for(IndexType i : selected) {
      maskedPositions_.push_back(i);                // where is the original word?
      maskedWords_.push_back(words[i]);             // what is the original word?
      words[i] = maskOut(words[i], maskId, engine); // mask that position
    }

    annotateSentenceIndices(dimTypeVocab);
  }

  // Takes a corpus batch and only annotates sentence indices, no masking is performed.
  // The mask symbol is set to a placeholder since it is never looked up by this constructor.
  BertBatch(Ptr<CorpusBatch> batch,
            const std::string& sepSymbol,
            const std::string& clsSymbol,
            int dimTypeVocab)
    : CorpusBatch(*batch),
      maskSymbol_("dummy"), sepSymbol_(sepSymbol), clsSymbol_(clsSymbol) {
    annotateSentenceIndices(dimTypeVocab);
  }

  // Computes a sentence (type) index for every word of the first sub-batch. Every batch entry
  // starts in sentence 0; the index is advanced after each separator symbol, but never beyond
  // dimTypeVocab - 1 so that it stays a valid row of a type-embedding table with dimTypeVocab rows.
  // Aborts if the separator symbol is not present in the vocabulary.
  void annotateSentenceIndices(int dimTypeVocab) {
    // BERT expects a textual first stream and a second stream with class labels
    auto subBatch = subBatches_.front();
    const auto& vocab = *subBatch->vocab();
    auto& words = subBatch->data();

    // Get word id of special symbols
    Word sepId   = vocab[sepSymbol_];
    ABORT_IF(sepId == vocab.getUnkId(),
             "BERT separator symbol {} not found in vocabulary", sepSymbol_);

    int dimBatch = (int)subBatch->batchSize();
    int dimWords = (int)subBatch->batchWidth();

    // create indices for BERT sentence embeddings A and B
    sentenceIndices_.resize(words.size()); // each word is either in sentence A or B
    std::vector<IndexType> sentPos(dimBatch, 0); // initialize each batch entry with being A [0]
    for(int i = 0; i < dimWords; ++i) {   // advance word-wise
      for(int j = 0; j < dimBatch; ++j) { // scan batch-wise
        int k = i * dimBatch + j;
        sentenceIndices_[k] = sentPos[j]; // set to current sentence position for batch entry
        // if current word is a separator and the next position is still a valid type index
        if(words[k] == sepId && (int)sentPos[j] + 1 < dimTypeVocab) {
          sentPos[j]++;                   // then increase sentence position for batch entry (to B [1])
        }
      }
    }
  }

  // Flat positions of the words that were selected for masking
  const std::vector<IndexType>& bertMaskedPositions() const { return maskedPositions_; }
  // Original words at the masked positions
  const Words& bertMaskedWords() const { return maskedWords_; }
  // Sentence (type) index for each word of the first sub-batch
  const std::vector<IndexType>& bertSentenceIndices() const { return sentenceIndices_; }
};

}

/**
 * BERT-specific version of EncoderClassifier, mostly here to automatically convert a
 * CorpusBatch to BertBatch.
 */
class BertEncoderClassifier : public EncoderClassifier, public data::RNGEngine { // @TODO: this random engine is not being serialized right now
public:
  BertEncoderClassifier(Ptr<Options> options)
  : EncoderClassifier(options) {}

  // Wraps the incoming CorpusBatch into a BertBatch according to the model type and forwards
  // it to EncoderClassifier::apply:
  //   "bert"            - masking plus sentence-index annotation (pre-training)
  //   "bert-classifier" - sentence-index annotation only (fine-tuning)
  // Aborts for any other model type.
  std::vector<Ptr<ClassifierState>> apply(Ptr<ExpressionGraph> graph, Ptr<data::CorpusBatch> batch, bool clearGraph) override {
    const std::string modelType = opt<std::string>("type");
    // Same default as used by BertEncoder when it creates the type embeddings, so that the
    // annotated sentence indices and the embedding table agree on the number of types.
    const int dimTypeVocab = opt<int>("bert-type-vocab-size", 2);

    // intercept batch and annotate with BERT-specific concepts
    Ptr<data::BertBatch> bertBatch;
    if(modelType == "bert") { // full BERT pre-training
      bertBatch = New<data::BertBatch>(batch,
                                       eng_,
                                       opt<float>("bert-masking-fraction", 0.15f), // 15% by default according to paper
                                       opt<std::string>("bert-mask-symbol"),
                                       opt<std::string>("bert-sep-symbol"),
                                       opt<std::string>("bert-class-symbol"),
                                       dimTypeVocab);
    } else if(modelType == "bert-classifier") { // we are probably fine-tuning a BERT model for a classification task
      bertBatch = New<data::BertBatch>(batch,
                                       opt<std::string>("bert-sep-symbol"),
                                       opt<std::string>("bert-class-symbol"),
                                       dimTypeVocab); // only annotate sentence separators
    } else {
      ABORT("Unknown BERT-style model: {}", modelType);
    }

    return EncoderClassifier::apply(graph, bertBatch, clearGraph);
  }

  // for externally created BertBatch for instance in BertValidator;
  // the batch is passed on unchanged, no masking or annotation happens here
  std::vector<Ptr<ClassifierState>> apply(Ptr<ExpressionGraph> graph, Ptr<data::BertBatch> bertBatch, bool clearGraph) {
    return EncoderClassifier::apply(graph, bertBatch, clearGraph);
  }
};

/**
 * BERT-specific modifications to EncoderTransformer
 * Actually all that is needed is to intercept the creation of special embeddings,
 * here sentence embeddings for sentence A and B.
 * @BUGBUG: transformer.h was meant to be compiled separately. I.e., one cannot derive from it.
 *          Is there a way to maybe instead include a reference in here, instead of deriving from it?
 */
class BertEncoder : public EncoderTransformer {
  using EncoderTransformer::EncoderTransformer;
public:
  // Adds sentence (type) embeddings to the given embeddings.
  //   embeddings           - input embeddings; the last three axes are read as
  //                          [dimWords, dimBatch, dimEmb]
  //   batch                - has to be a data::BertBatch, aborts otherwise
  //   learnedPosEmbeddings - if true, use a trainable embedding matrix "Wtype",
  //                          otherwise constant sinusoidal embeddings
  // Returns embeddings + sentence embedding signal.
  Expr addSentenceEmbeddings(Expr embeddings,
                             Ptr<data::CorpusBatch> batch,
                             bool learnedPosEmbeddings) const {
    Ptr<data::BertBatch> bertBatch = std::dynamic_pointer_cast<data::BertBatch>(batch);
    ABORT_IF(!bertBatch, "Batch must be BertBatch for BERT training or fine-tuning");

    int dimEmb = embeddings->shape()[-1];
    int dimBatch = embeddings->shape()[-2];
    int dimWords = embeddings->shape()[-3];

    // number of distinct sentence types, i.e. number of rows of the type-embedding table
    int dimTypeVocab = opt<int>("bert-type-vocab-size", 2);

    Expr signal;
    if(learnedPosEmbeddings) {
      auto sentenceEmbeddings = embedding()
                               ("prefix", "Wtype")
                               ("dimVocab", dimTypeVocab) // sentence A or sentence B
                               ("dimEmb", dimEmb)
                               .construct(graph_);
      signal = sentenceEmbeddings->applyIndices(bertBatch->bertSentenceIndices(), {dimWords, dimBatch, dimEmb});
    } else {
      // @TODO: factory for positional embeddings?
      // constant sinusoidal position embeddings, no backprop.
      // One row per sentence type: the table has to cover every index produced for the batch,
      // hence dimTypeVocab rows rather than a hard-coded 2.
      auto sentenceEmbeddingsExpr = graph_->constant({dimTypeVocab, dimEmb}, inits::sinusoidalPositionEmbeddings(0));
      signal = rows(sentenceEmbeddingsExpr, bertBatch->bertSentenceIndices());
      signal = reshape(signal, {dimWords, dimBatch, dimEmb});
    }

    return embeddings + signal;
  }

  // Adds positional embeddings first and sentence embeddings second. Whether each of them is
  // trained is read from "transformer-train-position-embeddings" and
  // "bert-train-type-embeddings", both defaulting to true.
  virtual Expr addSpecialEmbeddings(Expr input, int start = 0, Ptr<data::CorpusBatch> batch = nullptr) const override {
    bool trainPosEmbeddings = opt<bool>("transformer-train-position-embeddings", true);
    bool trainTypeEmbeddings = opt<bool>("bert-train-type-embeddings", true);
    input = addPositionalEmbeddings(input, start, trainPosEmbeddings);
    input = addSentenceEmbeddings(input, batch, trainTypeEmbeddings);
    return input;
  }
};

/**
 * BERT-specific classifier
 * Can be used for next sentence prediction task or other fine-tuned down-stream tasks
 * Does not actually need a BertBatch, works with CorpusBatch.
 *
 * @TODO: This is in fact so generic that we might move it out of here as the typical classifier implementation
 */
class BertClassifier : public ClassifierBase {
  using ClassifierBase::ClassifierBase;
public:
  // Computes class logits from the encoder output at the first position of every sequence and
  // returns them, together with the labels of sub-batch batchIndex_, as a ClassifierState.
  // Aborts unless exactly one encoder state is given.
  Ptr<ClassifierState> apply(Ptr<ExpressionGraph> graph, Ptr<data::CorpusBatch> batch, const std::vector<Ptr<EncoderState>>& encoderStates) override {
    ABORT_IF(encoderStates.size() != 1, "Currently we only support a single encoder BERT model");

    // The [CLS] symbol is the first symbol in each sequence, so take position 0 along axis -3
    auto encoderContext = encoderStates[0]->getContext();
    auto clsVectors = slice(encoderContext, /*axis=*/-3, /*i=*/0);

    int dimModel = clsVectors->shape()[-1];
    // The target vocabulary is used as the set of class labels
    int numClasses = opt<std::vector<int>>("dim-vocabs")[batchIndex_];

    // Two-layer classifier head: dense layer with tanh followed by an output layer.
    // NOTE(review): the "_ff_logit_l2" prefix is set on the mlp factory itself and not on the
    // output layer; this is kept as in the original. Presumably the output layer inherits it
    // from the mlp options - confirm before moving it, parameter names depend on it.
    auto classifierNet = mlp::mlp()
                           .push_back(mlp::dense()
                                        ("prefix", prefix_ + "_ff_logit_l1")
                                        ("dim", dimModel)
                                        ("activation", (int)mlp::act::tanh)) // @TODO: do we actually need this?
                           .push_back(mlp::output()
                                        ("dim", numClasses))
                           ("prefix", prefix_ + "_ff_logit_l2")
                           .construct(graph);

    // class logits for each batch entry
    auto classLogits = classifierNet->apply(clsVectors);

    auto result = New<ClassifierState>();
    result->setLogProbs(classLogits);

    // Labels are filled externally, for BERT these are NextSentence prediction labels
    const auto& labels = (*batch)[batchIndex_]->data();
    result->setTargetWords(labels);

    return result;
  }

  // Nothing to reset
  virtual void clear() override {}
};

/**
 * This is a model that pretrains BERT for classification.
 * This is also a Classifier, but compared to the BertClassifier above needs the BERT-specific information from BertBatch
 * as this is self-generating its labels from the source. Labels are dynamically created as complements of the masking process.
 */
class BertMaskedLM : public ClassifierBase {
  using ClassifierBase::ClassifierBase;
public:
  // Predicts the original words at the masked positions of a BertBatch.
  // Selects the encoder outputs at the masked positions, passes them through a dense layer,
  // the activation named by "transformer-ffn-activation", layer normalization and an output
  // layer tied (transposed) to "Wemb". The returned state holds the logits and the original
  // (pre-masking) words as targets.
  // Aborts if the batch is not a BertBatch, if there is not exactly one encoder state or if
  // the activation is not one of relu, swish, gelu.
  Ptr<ClassifierState> apply(Ptr<ExpressionGraph> graph, Ptr<data::CorpusBatch> batch, const std::vector<Ptr<EncoderState>>& encoderStates) override {
    auto maskedBatch = std::dynamic_pointer_cast<data::BertBatch>(batch);

    ABORT_IF(!maskedBatch, "Batch must be BertBatch for BERT training");
    ABORT_IF(encoderStates.size() != 1, "Currently we only support a single encoder BERT model");

    auto encoderContext = encoderStates[0]->getContext();

    // positions in batch of masked entries, and the vocab ids that have been masked there
    auto positionIndices   = graph->indices(maskedBatch->bertMaskedPositions());
    const auto& targetWords = maskedBatch->bertMaskedWords();

    int dimModel = encoderContext->shape()[-1];
    int dimBatch = encoderContext->shape()[-2];
    int dimTime  = encoderContext->shape()[-3];

    // flatten time and batch into one axis, then keep only rows that have actually been masked out
    auto flatContext   = reshape(encoderContext, {dimBatch * dimTime, dimModel});
    auto maskedContext = rows(flatContext, positionIndices);

    int dimVoc = opt<std::vector<int>>("dim-vocabs")[batchIndex_];

    // first layer: dense projection that keeps the model dimension
    auto hiddenLayer = mlp::mlp()
                         .push_back(mlp::dense()
                                      ("prefix", prefix_ + "_ff_logit_l1")
                                      ("dim", dimModel))
                         .construct(graph);

    auto hidden = hiddenLayer->apply(maskedContext);

    // apply the same activation type that is configured for the transformer FFN blocks
    std::string activationType = opt<std::string>("transformer-ffn-activation");
    if(activationType == "relu") {
      hidden = relu(hidden);
    } else if(activationType == "swish") {
      hidden = swish(hidden);
    } else if(activationType == "gelu") {
      hidden = gelu(hidden);
    } else {
      ABORT("Activation function {} not supported in BERT masked LM", activationType);
    }

    // layer normalization with trainable scale and bias
    auto lnScale = graph->param(prefix_ + "_ff_ln_scale", {1, dimModel}, inits::ones());
    auto lnBias  = graph->param(prefix_ + "_ff_ln_bias",  {1, dimModel}, inits::zeros());
    hidden = layerNorm(hidden, lnScale, lnBias);

    // output layer over the vocabulary, tied to the transposed embedding matrix "Wemb"
    auto outputLayer = mlp::mlp()
                         .push_back(mlp::output(
                                      "prefix", prefix_ + "_ff_logit_l2",
                                      "dim", dimVoc)
                                    .tieTransposed("Wemb"))
                         .construct(graph);

    auto vocabLogits = outputLayer->apply(hidden);

    auto result = New<ClassifierState>();
    result->setLogProbs(vocabLogits);
    result->setTargetWords(targetWords);

    return result;
  }

  // Nothing to reset
  virtual void clear() override {}
};

}