Welcome to mirror list, hosted at ThFree Co, Russian Federation.

output.cpp « layers « src - github.com/marian-nmt/marian.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
blob: af72b79415ddbd2e24d3b8521df9ca40860e1509 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
#include "output.h"
#include "common/timer.h"
#include "data/factored_vocab.h"
#include "layers/loss.h"

namespace marian {
namespace mlp {

/*private*/ void Output::lazyConstruct(int inputDim) {
  // Builds the output-layer parameters on first use. This cannot happen in the constructor
  // because neither the tied parameter nor the input dimension is known there yet.
  // A non-null Wt_ doubles as the "already constructed" flag.
  if(Wt_)
    return;

  const auto prefix = options_->get<std::string>("prefix");
  int vocabDim = options_->get<int>("dim");

  // If a factored vocabulary can be loaded from "vocab", its factor-vocab size overrides "dim".
  factoredVocab_ = FactoredVocab::tryCreateAndLoad(options_->get<std::string>("vocab", ""));
  if(factoredVocab_) {
    vocabDim = (int)factoredVocab_->factorVocabSize();
    LOG_ONCE(info, "[embedding] Factored outputs enabled");
  }

  // Output projection matrix: either the tied parameter as-is, the legacy untransposed
  // "_W" of shape [inputDim x vocabDim] (only when such a parameter already exists in the graph),
  // or the regular transposed "_Wt" of shape [vocabDim x inputDim].
  if(tiedParam_) {
    Wt_ = tiedParam_;
  } else if(graph_->get(prefix + "_W")) {
    Wt_ = graph_->param(
        prefix + "_W", {inputDim, vocabDim}, inits::glorotUniform(true, false));
    isLegacyUntransposedW = true;
  } else {
    Wt_ = graph_->param(
        prefix + "_Wt", {vocabDim, inputDim}, inits::glorotUniform(false, true));
  }

  if(hasBias_)
    b_ = graph_->param(prefix + "_b", {1, vocabDim}, inits::zeros());

  int lemmaDimEmb = options_->get<int>("lemma-dim-emb", 0);
  ABORT_IF(lemmaDimEmb && !factoredVocab_, "--lemma-dim-emb requires a factored vocabulary");
  if(lemmaDimEmb <= 0)
    return;  // only positive values (re-embedding of the lemma) need an extra matrix

  // Note: this macro stays defined for the rest of the file; applyAsLogits() tests it too.
#define HARDMAX_HACK
#ifdef HARDMAX_HACK
  // An odd value selects hard-max in applyAsLogits(); the matrix uses the even dimension.
  lemmaDimEmb &= ~1;
#endif

  // Re-embedding matrix over the entries of factor group 0 (the lemmas), stored transposed
  // for speed: [L x U], L = lemmaDimEmb. Fan-in init -> embedding vectors of roughly unit length.
  const auto lemmaRange = factoredVocab_->getGroupRange(0);
  const auto numLemmas = (int)(lemmaRange.second - lemmaRange.first);
  lemmaEt_ = graph_->param(prefix + "_lemmaEt",
                           {lemmaDimEmb, numLemmas},
                           inits::glorotUniform(/*fanIn=*/true, /*fanOut=*/false));
}

// Projects the hidden state `input` onto the output vocabulary and returns the logits.
//
// Parameters are created lazily on the first call (lazyConstruct()). If a shortlist is set, it is
// first filtered against `input`. Then one of three paths is taken:
//  - factored vocabulary: each factor group g uses its own slice of Wt_/b_ (group 0 uses the
//    cached shortlisted weights when a shortlist is present). The result is a multi-factor Logits
//    holding one RationalLoss per non-empty group. Option "lemma-dim-emb" optionally conditions
//    the later groups on group 0 (the lemmas):
//      -1   lemma-dependent bias added to the logits of groups > 0 (no shortlist support)
//      -2   gated, transformer-like block driven by the softmax-expected lemma embedding
//      -3   as -2, but with the hard-max lemma
//      > 0  re-embed the (expected) lemma at that dimension and add it to the hidden state;
//           under HARDMAX_HACK an odd value selects hard-max and the dimension is value & ~1
//  - shortlist without factors: logits over the shortlisted entries only. Axis 1 of `input`
//    must have size 1 (a single decoding time step).
//  - otherwise: one affine projection (dot() when there is no bias) over the full vocabulary.
//
// @param input hidden states; the last axis is the model dimension D
// @return the logits wrapped in a Logits object
Logits Output::applyAsLogits(Expr input) /*override final*/ {
  lazyConstruct(input->shape()[-1]);

  // Wt_ is normally stored transposed ([U x D]); legacy models store it untransposed ([D x U]).
  const bool transposeW = !isLegacyUntransposedW;

  // Full projection: affine() when a bias is present, plain dot() otherwise.
  auto affineOrDot = [](Expr x, Expr W, Expr b, bool transA, bool transB) -> Expr {
    if(b)
      return affine(x, W, b, transA, transB);
    return dot(x, W, transA, transB);
  };

  // Projection against shortlisted weights. A dynamic (LSH) shortlist yields a separate W per
  // beam/batch entry, hence needs bdot(), and does not support a bias.
  auto affineShortlist = [this](Expr x, Expr W, Expr b, bool transA, bool transB) -> Expr {
    if(b) {
      // original shortlist. W always has 1 for beam & batch
      ABORT_UNLESS(!shortlist_->isDynamic(), "affineShortlist. Bias not supported with LSH/dynamic shortlist"); // todo rename ABORT_UNLESS to ASSERT
      return affine(x, W, b, transA, transB);
    }
    if(shortlist_->isDynamic()) {
      // LSH produces W entry for each beam and batch => need bdot()
      ABORT_IF(!(!transA && transB), "affineShortlist. Only tested with transA==0 and transB==1");
      return bdot(x, W, transA, transB);
    }
    // original shortlist. W always has 1 for beam & batch
    return dot(x, W, transA, transB);
  };

  if(shortlist_) {
    shortlist_->filter(input, Wt_, isLegacyUntransposedW, b_, lemmaEt_);
  }

  if(factoredVocab_) {
    // Loop-invariant; read once. The re-embedding branch adjusts a local copy, never this value.
    const int lemmaDimEmb = options_->get<int>("lemma-dim-emb", 0);

    // project each factor separately
    auto numGroups = factoredVocab_->getNumGroups();
    std::vector<Ptr<RationalLoss>> allLogits(numGroups,
                                             nullptr);  // (note: null entries for absent factors)
    Expr input1 = input;                                // [B... x D]
    Expr Plemma = nullptr;                              // used for lemmaDimEmb=-1
    Expr inputLemma = nullptr;                          // used for lemmaDimEmb=-2, -3
    for(size_t g = 0; g < numGroups; g++) {
      auto range = factoredVocab_->getGroupRange(g);
      if(g > 0 && range.first == range.second)  // empty entry
        continue;
      ABORT_IF(g > 0 && range.first != factoredVocab_->getGroupRange(g - 1).second,
               "Factor groups must be consecutive (group {} vs predecessor)",
               g);
      // slice this group's section out of W_ (group 0 uses the shortlisted weights, if any)
      Expr factorWt, factorB;
      if(g == 0 && shortlist_) {
        factorWt = shortlist_->getCachedShortWt();
        factorB = shortlist_->getCachedShortb();
      } else {
        factorWt = slice(
            Wt_, isLegacyUntransposedW ? -1 : 0, Slice((int)range.first, (int)range.second));
        if(hasBias_)
          factorB = slice(b_, -1, Slice((int)range.first, (int)range.second));
      }
      if((lemmaDimEmb == -2 || lemmaDimEmb == -3)
         && g > 0) {  // -2/-3 means a gated transformer-like structure (-3 = hard-max)
        LOG_ONCE(info, "[embedding] using lemma conditioning with gate");
        // this mimics one transformer layer
        //  - attention over two inputs:
        //     - e = current lemma. We use the original embedding vector; specifically, expectation
        //     over all lemmas.
        //     - input = hidden state FF(h_enc+h_dec)
        //  - dot-prod attention to allow both sides to influence (unlike our recurrent
        //  self-attention)
        //  - multi-head to allow for multiple conditions to be modeled
        //  - add & norm, for gradient flow and scaling
        //  - FF layer   --this is expensive; it is per-factor
        // multi-head attention
        int inputDim = input->shape()[-1];
        int heads = 8;
        auto name = options_->get<std::string>("prefix") + "_factor" + std::to_string(g);
        auto Wq = graph_->param(name + "_Wq", {inputDim, inputDim}, inits::glorotUniform());
        auto Wk = graph_->param(name + "_Wk", {inputDim, inputDim}, inits::glorotUniform());
        auto Wv = graph_->param(name + "_Wv", {inputDim, inputDim}, inits::glorotUniform());
        auto toMultiHead = [&](Expr x, int heads) {
          const auto& shape = x->shape();
          int inputDim = shape[-1];
          int otherDim = shape.elements() / inputDim;
          ABORT_IF(inputDim / heads * heads != inputDim,
                   "inputDim ({}) must be multiple of number of heads ({})",
                   inputDim,
                   heads);
          return reshape(x, {otherDim, heads, 1, inputDim / heads});
        };
        input1 = inputLemma;
        auto qm = toMultiHead(dot(input1, Wq), heads);  // [B... x H x D/H] projected query
        auto kdm = toMultiHead(dot(input1 - input, Wk),
                               heads);  // [B... x H x D/H] the two data vectors projected as keys.
                                        // Use diff and sigmoid, instead of softmax.
        auto vem = toMultiHead(
            dot(input1, Wv),
            heads);  // [B... x H x D/H] one of the two data vectors projected as values
        auto vim = toMultiHead(dot(input, Wv), heads);  // [B... x H x D/H] the other
        auto zm = bdot(qm, kdm, false, true);           // [B... x H x 1]
        auto sm = sigmoid(zm);                          // [B... x H x 1]
        auto rm = sm * (vem - vim) + vim;               // [B... x H x D/H]
        auto r = reshape(rm, input->shape());           // [B... x D]
        // add & norm
        input1 = r + input1;
        input1 = layerNorm(input1, name + "_att");
        // FF layer
        auto ffnDropProb = 0.1f;     // @TODO: get as a parameter
        auto ffnDim = inputDim * 2;  // @TODO: get as a parameter
        auto f = denseInline(input1,
                             name + "_ffn",
                             /*suffix=*/"1",
                             ffnDim,
                             inits::glorotUniform(),
                             "relu",
                             ffnDropProb);
        f = denseInline(f, name + "_ffn", /*suffix=*/"2", inputDim);
        // add & norm
        input1 = f + input1;
        input1 = layerNorm(input1, name + "_ffn");
      }
      // @TODO: b_ should be a vector, not a matrix; but shotlists use cols() in, which requires a
      // matrix
      Expr factorLogits;
      if(g == 0 && shortlist_) {
        // swap axes 1 and 2 for the shortlisted projection, and swap them back afterwards
        Expr tmp = transpose(input1, {0, 2, 1, 3});
        factorLogits = affineShortlist(
            tmp, factorWt, factorB, false, /*transB=*/transposeW);  // [B... x U] factor logits
        factorLogits = transpose(factorLogits, {0, 2, 1, 3});
      } else {
        factorLogits = affineOrDot(
            input1, factorWt, factorB, false, /*transB=*/transposeW);  // [B... x U] factor logits
      }

      // optionally add lemma-dependent bias
      if(Plemma) {  // [B... x U0]
        int lemmaVocabDim = Plemma->shape()[-1];
        int factorVocabDim = factorLogits->shape()[-1];
        auto name = options_->get<std::string>("prefix");
        Expr lemmaBt
            = graph_->param(name + "_lemmaBt_" + std::to_string(g),
                            {factorVocabDim, lemmaVocabDim},
                            inits::zeros());  // [U x U0] U0=#lemmas one bias per class per lemma
        auto b = dot(Plemma, lemmaBt, false, true);  // [B... x U]
        factorLogits = factorLogits + b;
      }
      allLogits[g] = New<RationalLoss>(factorLogits, nullptr);
      // optionally add a soft embedding of lemma back to create some lemma dependency
      // @TODO: if this works, move it into lazyConstruct
      if(lemmaDimEmb == -2 && g == 0) {  // -2 means a gated transformer-like structure
        LOG_ONCE(info, "[embedding] using lemma conditioning with gate, soft-max version");
        // get expected lemma embedding vector
        auto factorLogSoftmax = logsoftmax(
            factorLogits);  // [B... x U] note: with shortlist, this is not the full lemma set
        auto factorSoftmax = exp(factorLogSoftmax);
        inputLemma = dot(factorSoftmax,
                         factorWt,
                         false,
                         /*transB=*/isLegacyUntransposedW);  // [B... x D]
      } else if(lemmaDimEmb == -3 && g == 0) {  // same as -2 except with hard max
        LOG_ONCE(info, "[embedding] using lemma conditioning with gate, hard-max version");
        // get max-lemma embedding vector
        auto maxVal = max(factorLogits,
                          -1);  // [B... x U] note: with shortlist, this is not the full lemma set
        auto factorHardmax = eq(factorLogits, maxVal);
        inputLemma = dot(factorHardmax,
                         factorWt,
                         false,
                         /*transB=*/isLegacyUntransposedW);  // [B... x D]
      } else if(lemmaDimEmb == -1 && g == 0) {  // -1 means learn a lemma-dependent bias
        ABORT_IF(shortlist_, "Lemma-dependent bias with short list is not yet implemented");
        LOG_ONCE(info, "[embedding] using lemma-dependent bias");
        auto factorLogSoftmax
            = logsoftmax(factorLogits);  // (we do that again later, CSE will kick in)
        auto z = /*stopGradient*/ (factorLogSoftmax);
        Plemma = exp(z);                      // [B... x U]
      } else if(lemmaDimEmb > 0 && g == 0) {  // > 0 means learn a re-embedding matrix
        LOG_ONCE(info, "[embedding] enabled re-embedding of lemma, at dim {}", lemmaDimEmb);
        int reEmbedDim = lemmaDimEmb;  // actual re-embedding dim; HARDMAX_HACK may clear bit 0
        // compute softmax. We compute logsoftmax() separately because this way, computation will be
        // reused later via CSE
        auto factorLogSoftmax = logsoftmax(factorLogits);
        auto factorSoftmax = exp(factorLogSoftmax);
#ifdef HARDMAX_HACK
        // odd value triggers hardmax for now (for quick experimentation)
        const bool hardmax = (reEmbedDim & 1) != 0;
        if(hardmax) {
          reEmbedDim = reEmbedDim & ~1;
          LOG_ONCE(info, "[embedding] HARDMAX_HACK enabled. Actual dim is {}", reEmbedDim);
          auto maxVal = max(factorSoftmax, -1);
          factorSoftmax = eq(factorSoftmax, maxVal);
        }
#endif
        // re-embedding lookup, soft-indexed by softmax
        Expr e;
        if(shortlist_) {  // short-listed version of re-embedding matrix
          Expr cachedShortLemmaEt = shortlist_->getCachedShortLemmaEt();
          const Shape &fShape = factorSoftmax->shape();
          ABORT_IF(fShape[1] != 1, "We are decoding with a shortlist but time step size {} != 1??", fShape[1]);
          factorSoftmax = reshape(factorSoftmax, {fShape[0], fShape[2], 1, fShape[3]}); // we can switch dims because time step is of size 1
          e = bdot(factorSoftmax, cachedShortLemmaEt, false, true);
          const Shape &eShape = e->shape();
          e = reshape(e, {eShape[0], 1, eShape[1], eShape[3]}); // switch dims back, again possible because time step is of size 1
        } else { // for scoring, training and decoding without a shortlist we use a simple dot operation
          e = dot(factorSoftmax,
                  lemmaEt_,
                  false,
                  true);  // [B... x L]
        }

        // project it back to regular hidden dim
        int inputDim = input1->shape()[-1];
        auto name = options_->get<std::string>("prefix");
        // note: if the lemmaEt[:,w] have unit length (var = 1/L), then lemmaWt @ lemmaEt is also
        // length 1
        Expr lemmaWt
            = inputDim == reEmbedDim
                  ? nullptr
                  : graph_->param(name + "_lemmaWt",
                                  {inputDim, reEmbedDim},
                                  inits::glorotUniform());    // [D x L] D=hidden-vector dimension
        auto f = lemmaWt ? dot(e, lemmaWt, false, true) : e;  // [B... x D]
        // augment the original hidden vector with this additional information
        input1 = input1 + f;
      }
    }
    return Logits(std::move(allLogits), factoredVocab_);
  } else if(shortlist_) {
    const Shape &inputShape = input->shape();
    // Explicit check instead of assert(): assert() is compiled out under NDEBUG (release builds).
    ABORT_IF(inputShape[1] != 1,
             "We are decoding with a shortlist but time step size {} != 1??",
             inputShape[1]);
    // axis 1 has size 1, so axes 1 and 2 can be swapped by a plain reshape
    input = reshape(input, {inputShape[0], inputShape[2], 1, inputShape[3]});

    Expr Wt = shortlist_->getCachedShortWt();
    Expr b = shortlist_->getCachedShortb();
    Expr ret = affineShortlist(input, Wt, b, false, /*transB=*/transposeW);
    const Shape &retShape = ret->shape();
    ABORT_IF(retShape[2] != 1,
             "Shortlisted logits expected to have size 1 on axis 2, got {}",
             retShape[2]);
    // swap axes 1 and 2 back
    ret = reshape(ret, {retShape[0], 1, retShape[1], retShape[3]});
    return Logits(ret);
  } else {
    return Logits(affineOrDot(input, Wt_, b_, false, /*transB=*/transposeW));
  }
}

}  // namespace mlp
}  // namespace marian