#pragma once
#include "generic.h"
#include "marian.h"

namespace marian {

class FactoredVocab;

// A regular embedding layer.
// Note that this also applies dropout if the option is passed (pass 0 when in inference mode).
// It is best to not use Embedding directly, but rather via getEmbeddingLayer() in
// EncoderDecoderLayerBase, which knows to pass on all required parameters from options.
class Embedding : public LayerBase, public IEmbeddingLayer {
  Expr E_;
  Expr FactorEmbMatrix_; // factor embedding matrix, used when combining lemma and factor embeddings by concatenation
  Ptr<FactoredVocab> factoredVocab_;
  Expr multiRows(const Words& data, float dropProb) const;
  Expr embedWithConcat(const Words& data) const;
  bool inference_{false};

public:
  Embedding(Ptr<ExpressionGraph> graph, Ptr<Options> options);

  std::tuple<Expr /*embeddings*/, Expr /*mask*/> apply(
      Ptr<data::SubBatch> subBatch) const override final;

  Expr apply(const Words& words, const Shape& shape) const override final;

  Expr applyIndices(const std::vector<WordIndex>& embIdx, const Shape& shape) const override final;
};
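
// A minimal usage sketch for Embedding (hedged: the exact option keys, e.g. "dimVocab" and
// "prefix", and the sample values are assumptions, since the constructor reads its configuration
// from the Options object rather than from explicit arguments):
//
//   auto embOpts = New<Options>("prefix",    "Wemb",   // parameter name (assumption)
//                               "dimVocab",  32000,    // vocabulary size (assumption)
//                               "dimEmb",    512,      // embedding dimension
//                               "dropout",   0.1f,     // pass 0 in inference mode
//                               "inference", false);
//   Ptr<IEmbeddingLayer> embedding = New<Embedding>(graph, embOpts);
//   Expr batchEmbeddings, batchMask;
//   std::tie(batchEmbeddings, batchMask) = embedding->apply(subBatch);  // subBatch: Ptr<data::SubBatch>
//
// As noted above, prefer getEmbeddingLayer() in EncoderDecoderLayerBase (see generic.h), which
// forwards all required options automatically.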

class ULREmbedding : public LayerBase, public IEmbeddingLayer {
  std::vector<Expr> ulrEmbeddings_;  // @TODO: These could now better be written as 6 named class members
  bool inference_{false};

public:
  ULREmbedding(Ptr<ExpressionGraph> graph, Ptr<Options> options)
      : LayerBase(graph, options), inference_(opt<bool>("inference")) {
    std::string name = "ulr_embed";  // opt<std::string>("prefix");
    int dimKeys      = opt<int>("dimTgtVoc");
    int dimQueries   = opt<int>("dimSrcVoc");
    int dimEmb       = opt<int>("dimEmb");
    int dimUlrEmb    = opt<int>("dimUlrEmb");  // ULR mono embed size
    bool fixed       = opt<bool>("fixed", false);

    // Embedding layer initialization should depend only on embedding size, hence fanIn=false
    auto initFunc = inits::glorotUniform(/*fanIn=*/false, /*fanOut=*/true);

    std::string queryFile = opt<std::string>("ulrQueryFile");
    std::string keyFile   = opt<std::string>("ulrKeysFile");
    bool trainTrans       = opt<bool>("ulrTrainTransform", false);
    if(!queryFile.empty() && !keyFile.empty()) {
      initFunc         = inits::fromWord2vec(queryFile, dimQueries, dimUlrEmb, false);
      name             = "ulr_query";
      fixed            = true;
      auto query_embed = graph_->param(name, {dimQueries, dimUlrEmb}, initFunc, fixed);
      ulrEmbeddings_.push_back(query_embed);
      // key embeddings
      initFunc       = inits::fromWord2vec(keyFile, dimKeys, dimUlrEmb, false);
      name           = "ulr_keys";
      fixed          = true;
      auto key_embed = graph_->param(name, {dimKeys, dimUlrEmb}, initFunc, fixed);
      ulrEmbeddings_.push_back(key_embed);
      // actual trainable embedding
      initFunc = inits::glorotUniform();
      name     = "ulr_embed";
      fixed    = false;
      auto ulr_embed = graph_->param(name, {dimKeys, dimEmb}, initFunc, fixed);  // note the reverse dim
      ulrEmbeddings_.push_back(ulr_embed);
      // initialize the trainable source embedding
      name               = "ulr_src_embed";
      auto ulr_src_embed = graph_->param(name, {dimQueries, dimEmb}, initFunc, fixed);
      ulrEmbeddings_.push_back(ulr_src_embed);
      // ULR transformation matrix
      // initFunc = inits::eye(1.f); // identity matrix - is it OK to initialize with identity,
      // or should we restrict this to the fixed case only?
      if(trainTrans) {
        initFunc = inits::glorotUniform();
        fixed    = false;
      } else {
        initFunc = inits::eye();  // identity matrix
        fixed    = true;
      }
      name              = "ulr_transform";
      auto ulrTransform = graph_->param(name, {dimUlrEmb, dimUlrEmb}, initFunc, fixed);
      ulrEmbeddings_.push_back(ulrTransform);

      initFunc = inits::fromValue(
          1.f);  // TBD: we should read the sharable flags here; 1 means fully sharable, 0 means
                 // no universal embedding; should be zero for the top-frequency words only
      fixed            = true;
      name             = "ulr_shared";
      auto share_embed = graph_->param(name, {dimQueries, 1}, initFunc, fixed);
      ulrEmbeddings_.push_back(share_embed);
    }
  }

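  // The computation performed by apply() below, summarized as a formula (a restatement of the
  // code that follows, not a separate specification): for each source position x with query row
  // Q[x], trainable source embedding I[x], and sharable flag alpha[x],
  //
  //   weights = softmax( (Q[x] * A * K^T) / (tau * sqrt(dimUlrEmb)) )
  //   e(x)    = I[x] + alpha[x] * (weights * E)
  //
  // where A is ulr_transform, K the fixed key embeddings, E the trainable universal embeddings,
  // and tau the softmax temperature.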
  std::tuple<Expr /*embeddings*/, Expr /*mask*/> apply(
      Ptr<data::SubBatch> subBatch) const override final {
    auto queryEmbed   = ulrEmbeddings_[0];  // Q : dimQueries*dimUlrEmb
    auto keyEmbed     = ulrEmbeddings_[1];  // K : dimKeys*dimUlrEmb
    auto uniEmbed     = ulrEmbeddings_[2];  // E : dimQueries*dimEmb
    auto srcEmbed     = ulrEmbeddings_[3];  // I : dimQueries*dimEmb
    auto ulrTransform = ulrEmbeddings_[4];  // A : dimUlrEmb *dimUlrEmb
    auto ulrSharable  = ulrEmbeddings_[5];  // alpha : dimQueries*1
    int dimBatch      = (int)subBatch->batchSize();
    int dimEmb        = uniEmbed->shape()[-1];
    int dimWords      = (int)subBatch->batchWidth();
    // D = K.A.QT
    // dim K = univ_tok_vocab x uni_embed_size
    // dim A = uni_embed_size x uni_embed_size
    // dim Q = total_merged_vocab_size x uni_embed_size (so Q^T is uni_embed_size x total_merged_vocab_size)
    // dim D = univ_tok_vocab x total_merged_vocab_size
    // Note: all of the above could be precomputed and serialized if A is not trainable, and
    // reused during decoding (TBD). Here we handle the mini-batch: extract from Q the rows
    // corresponding to the source words in this mini-batch.
    auto embIdx          = toWordIndexVector(subBatch->data());
    auto queryEmbeddings = rows(queryEmbed, embIdx);
    auto srcEmbeddings   = rows(srcEmbed, embIdx);     // extract trainable src embeddings
    auto alpha           = rows(ulrSharable, embIdx);  // extract sharable flags
    auto qt              = dot(queryEmbeddings, ulrTransform, false, false);  // transform query embeddings with A (dimUlrEmb x dimUlrEmb)
    auto sqrtDim         = std::sqrt((float)queryEmbeddings->shape()[-1]);
    qt = qt / sqrtDim;  // normalize by the embedding size to keep the dot product from growing
                        // large in magnitude for larger embedding sizes
    auto z         = dot(qt, keyEmbed, false, true);                   // query-key similarity
    float dropProb = this->options_->get<float>("ulr-dropout", 0.0f);  // default no dropout
    if(!inference_)
      z = dropout(z, dropProb);

    float tau
        = this->options_->get<float>("ulr-softmax-temperature", 1.0f);  // default: no temperature
    // The softmax temperature controls the sharpness of the distribution: with a high
    // temperature the outputs move closer to each other (more uniform), while with a low
    // temperature the softmax approaches a "hardmax" (more peaked).
    auto weights = softmax(z / tau);  // softmax over the last axis (dim = -1); tau is a scalar temperature
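    // For example, with logits z = [2, 1]: tau = 1 gives weights of roughly [0.73, 0.27],
    // tau = 10 gives roughly [0.52, 0.48] (closer to uniform), and tau = 0.1 gives roughly
    // [1.0, 0.0] (close to hardmax).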
    auto chosenEmbeddings = dot(weights, uniEmbed);  // weighted average of the universal embeddings
    auto chosenEmbeddings_mix = srcEmbeddings + alpha * chosenEmbeddings;  // element-wise, with alpha broadcast over the embedding dimension
    auto batchEmbeddings = reshape(chosenEmbeddings_mix, {dimWords, dimBatch, dimEmb});
    auto graph           = ulrEmbeddings_.front()->graph();
    auto batchMask = graph->constant({dimWords, dimBatch, 1}, inits::fromVector(subBatch->mask()));
    if(!inference_)
      batchEmbeddings = dropout(batchEmbeddings,
                                options_->get<float>("dropout-embeddings", 0.0f),
                                {batchEmbeddings->shape()[-3], 1, 1});
    return std::make_tuple(batchEmbeddings, batchMask);
  }

  Expr apply(const Words& words, const Shape& shape) const override final {
    return applyIndices(toWordIndexVector(words), shape);
  }

  Expr applyIndices(const std::vector<WordIndex>& embIdx, const Shape& shape) const override final {
    embIdx; shape;  // silence unused-parameter warnings
    ABORT("not implemented");  // @TODO: implement me
  }
};

}  // namespace marian