Welcome to mirror list, hosted at ThFree Co, Russian Federation.

model.h « dl4mt « gpu « amun « src - github.com/marian-nmt/marian.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
blob: 0829d233f43e74f43c1ed608febd47c9c29d7ede (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
#pragma once

#include <map>
#include <memory>
#include <string>
#include <vector>

#include <yaml-cpp/yaml.h>

#include "gpu/mblas/tensor.h"
#include "gpu/npz_converter.h"

namespace amunmt {
namespace GPU {

// All model parameters of the dl4mt-style network used by the GPU backend.
//
// Each group of parameters (encoder embeddings, encoder RNNs, decoder init,
// decoder RNNs, attention, softmax) is a nested struct built from an
// NpzConverter by an out-of-line constructor. Copy construction is disabled
// everywhere, and the tensors are held through std::shared_ptr.
struct Weights {

  //////////////////////////////////////////////////////////////////////////////
  // Embedding weights read by the encoder.
  struct EncEmbeddings {
    EncEmbeddings(const EncEmbeddings&) = delete;
    EncEmbeddings(const NpzConverter& model);

    // Embedding matrices for word factors. The first factor is the word
    // surface form. The rest are optional.
    std::vector<std::shared_ptr<mblas::Tensor>> Es_;
  };

  ////////////////////////////////////////////////////////////////////////////////////////////////////
  // Parameters of the forward-direction encoder GRU.
  // NOTE(review): Gamma_* look like layer-normalization gains and Bx1_/Bx2_
  // like split biases of the candidate activation -- confirm in model.cpp.
  struct EncForwardGRU {
    EncForwardGRU(const EncForwardGRU&) = delete;
    EncForwardGRU(const NpzConverter& model);

    const std::shared_ptr<mblas::Tensor> W_;
    const std::shared_ptr<mblas::Tensor> B_;
    const std::shared_ptr<mblas::Tensor> U_;
    const std::shared_ptr<mblas::Tensor> Wx_;
    const std::shared_ptr<mblas::Tensor> Bx1_;
    const std::shared_ptr<mblas::Tensor> Bx2_;
    const std::shared_ptr<mblas::Tensor> Ux_;
    const std::shared_ptr<mblas::Tensor> Gamma_1_;
    const std::shared_ptr<mblas::Tensor> Gamma_2_;
  };

  ////////////////////////////////////////////////////////////////////////////////////////////////////
  // Parameters of the backward-direction encoder GRU (same layout as
  // EncForwardGRU).
  struct EncBackwardGRU {
    EncBackwardGRU(const EncBackwardGRU&) = delete;
    EncBackwardGRU(const NpzConverter& model);

    const std::shared_ptr<mblas::Tensor> W_;
    const std::shared_ptr<mblas::Tensor> B_;
    const std::shared_ptr<mblas::Tensor> U_;
    const std::shared_ptr<mblas::Tensor> Wx_;
    const std::shared_ptr<mblas::Tensor> Bx1_;
    const std::shared_ptr<mblas::Tensor> Bx2_;
    const std::shared_ptr<mblas::Tensor> Ux_;
    const std::shared_ptr<mblas::Tensor> Gamma_1_;
    const std::shared_ptr<mblas::Tensor> Gamma_2_;
  };

  ////////////////////////////////////////////////////////////////////////////////////////////////////
  // Parameters of the forward-direction encoder LSTM. Unlike the GRU
  // variants, it has a single Bx_ instead of Bx1_/Bx2_.
  struct EncForwardLSTM {
    EncForwardLSTM(const EncForwardLSTM&) = delete;
    EncForwardLSTM(const NpzConverter& model);

    const std::shared_ptr<mblas::Tensor> W_;
    const std::shared_ptr<mblas::Tensor> B_;
    const std::shared_ptr<mblas::Tensor> U_;
    const std::shared_ptr<mblas::Tensor> Wx_;
    const std::shared_ptr<mblas::Tensor> Bx_;
    const std::shared_ptr<mblas::Tensor> Ux_;
    const std::shared_ptr<mblas::Tensor> Gamma_1_;
    const std::shared_ptr<mblas::Tensor> Gamma_2_;
  };

  ////////////////////////////////////////////////////////////////////////////////////////////////////
  // Parameters of the backward-direction encoder LSTM (same layout as
  // EncForwardLSTM).
  struct EncBackwardLSTM {
    EncBackwardLSTM(const EncBackwardLSTM&) = delete;
    EncBackwardLSTM(const NpzConverter& model);

    const std::shared_ptr<mblas::Tensor> W_;
    const std::shared_ptr<mblas::Tensor> B_;
    const std::shared_ptr<mblas::Tensor> U_;
    const std::shared_ptr<mblas::Tensor> Wx_;
    const std::shared_ptr<mblas::Tensor> Bx_;
    const std::shared_ptr<mblas::Tensor> Ux_;
    const std::shared_ptr<mblas::Tensor> Gamma_1_;
    const std::shared_ptr<mblas::Tensor> Gamma_2_;
  };

  ////////////////////////////////////////////////////////////////////////////////////////////////////

  // Embedding matrix read by the decoder.
  struct DecEmbeddings {
    DecEmbeddings(const DecEmbeddings&) = delete;
    DecEmbeddings(const NpzConverter& model);

    const std::shared_ptr<mblas::Tensor> E_;
  };

  ////////////////////////////////////////////////////////////////////////////////////////////////////

  // Parameters used to initialize the decoder.
  // NOTE(review): presumably maps encoder output to the initial decoder
  // state -- confirm against the decoder implementation.
  struct DecInit {
    DecInit(const DecInit&) = delete;
    DecInit(const NpzConverter& model);

    const std::shared_ptr<mblas::Tensor> Wi_;
    const std::shared_ptr<mblas::Tensor> Bi_;
    const std::shared_ptr<mblas::Tensor> Gamma_;
  };

  ////////////////////////////////////////////////////////////////////////////////////////////////////

  // Parameters of the first decoder GRU.
  struct DecGRU1 {
    DecGRU1(const DecGRU1&) = delete;
    DecGRU1(const NpzConverter& model);

    const std::shared_ptr<mblas::Tensor> W_;
    const std::shared_ptr<mblas::Tensor> B_;
    const std::shared_ptr<mblas::Tensor> U_;
    const std::shared_ptr<mblas::Tensor> Wx_;
    const std::shared_ptr<mblas::Tensor> Bx1_;
    const std::shared_ptr<mblas::Tensor> Bx2_;
    const std::shared_ptr<mblas::Tensor> Ux_;
    const std::shared_ptr<mblas::Tensor> Gamma_1_;
    const std::shared_ptr<mblas::Tensor> Gamma_2_;
  };

  ////////////////////////////////////////////////////////////////////////////////////////////////////

  // Parameters of the second decoder GRU.
  struct DecGRU2 {
    DecGRU2(const DecGRU2&) = delete;
    DecGRU2(const NpzConverter& model);

    const std::shared_ptr<mblas::Tensor> W_;
    const std::shared_ptr<mblas::Tensor> B_;
    const std::shared_ptr<mblas::Tensor> U_;
    const std::shared_ptr<mblas::Tensor> Wx_;
    // NOTE(review): Bx2_ is declared before Bx1_ here, the opposite of
    // DecGRU1. Members are initialized in declaration order, so the
    // out-of-line constructor may depend on this (e.g. sizing Bx1_ from
    // Bx2_). Do not reorder without checking model.cpp.
    const std::shared_ptr<mblas::Tensor> Bx2_;
    const std::shared_ptr<mblas::Tensor> Bx1_;
    const std::shared_ptr<mblas::Tensor> Ux_;
    const std::shared_ptr<mblas::Tensor> Gamma_1_;
    const std::shared_ptr<mblas::Tensor> Gamma_2_;
  };

  ////////////////////////////////////////////////////////////////////////////////////////////////////

  // Parameters of the first decoder LSTM.
  struct DecLSTM1 {
    DecLSTM1(const DecLSTM1&) = delete;
    DecLSTM1(const NpzConverter& model);

    const std::shared_ptr<mblas::Tensor> W_;
    const std::shared_ptr<mblas::Tensor> B_;
    const std::shared_ptr<mblas::Tensor> U_;
    const std::shared_ptr<mblas::Tensor> Wx_;
    const std::shared_ptr<mblas::Tensor> Bx_;
    const std::shared_ptr<mblas::Tensor> Ux_;
    const std::shared_ptr<mblas::Tensor> Gamma_1_;
    const std::shared_ptr<mblas::Tensor> Gamma_2_;
  };

  ////////////////////////////////////////////////////////////////////////////////////////////////////

  // Parameters of the second decoder LSTM.
  struct DecLSTM2 {
    DecLSTM2(const DecLSTM2&) = delete;
    DecLSTM2(const NpzConverter& model);

    const std::shared_ptr<mblas::Tensor> W_;
    const std::shared_ptr<mblas::Tensor> B_;
    const std::shared_ptr<mblas::Tensor> U_;
    const std::shared_ptr<mblas::Tensor> Wx_;
    const std::shared_ptr<mblas::Tensor> Bx_;
    const std::shared_ptr<mblas::Tensor> Ux_;
    const std::shared_ptr<mblas::Tensor> Gamma_1_;
    const std::shared_ptr<mblas::Tensor> Gamma_2_;
  };

  ////////////////////////////////////////////////////////////////////////////////////////////////////

  // A wrapper class to deserialize weights for multiplicative-LSTM,
  // multiplicative-GRU and such.
  //
  // Loads everything BaseWeights loads, then four extra tensors named
  // "<prefix>_Wm", "<prefix>_bm", "<prefix>_Um" and "<prefix>_bmu".
  // NOTE(review): the meaning of the boolean flags passed to
  // NpzConverter::get is defined in npz_converter.h -- confirm there.
  template<class BaseWeights>
  struct MultWeights: public BaseWeights {
    MultWeights(const MultWeights&) = delete;

    // model:  source of the tensors.
    // prefix: name prefix for the multiplicative tensors (joined with "_").
    MultWeights(const NpzConverter& model, const std::string& prefix)
      : BaseWeights(model),
      Wm_(model.get(p(prefix, "Wm"), true)),
      Bm_(model.get(p(prefix, "bm"), true, true)),
      Um_(model.get(p(prefix, "Um"), true)),
      Bmu_(model.get(p(prefix, "bmu"), true, true))
      {}

    const std::shared_ptr<mblas::Tensor> Wm_;
    const std::shared_ptr<mblas::Tensor> Bm_;
    const std::shared_ptr<mblas::Tensor> Um_;
    const std::shared_ptr<mblas::Tensor> Bmu_;

  private:
    // Builds the tensor key "<prefix>_<suffix>". Static because it uses no
    // instance state; it is called from the member-initializer list above.
    static std::string p(const std::string& prefix, const std::string& suffix) {
      return prefix + "_" + suffix;
    }
  };

  ////////////////////////////////////////////////////////////////////////////////////////////////////

  // Attention (alignment) parameters of the decoder.
  struct DecAlignment {
    DecAlignment(const DecAlignment&) = delete;
    DecAlignment(const NpzConverter& model);

    const std::shared_ptr<mblas::Tensor> V_;
    const std::shared_ptr<mblas::Tensor> W_;
    const std::shared_ptr<mblas::Tensor> B_;
    const std::shared_ptr<mblas::Tensor> U_;
    const std::shared_ptr<mblas::Tensor> C_;
    const std::shared_ptr<mblas::Tensor> Gamma_1_;
    const std::shared_ptr<mblas::Tensor> Gamma_2_;
  };

  ////////////////////////////////////////////////////////////////////////////////////////////////////

  // Output-layer parameters of the decoder: four weight/bias pairs plus
  // three Gamma tensors.
  struct DecSoftmax {
    DecSoftmax(const DecSoftmax&) = delete;
    DecSoftmax(const NpzConverter& model);

    const std::shared_ptr<mblas::Tensor> W1_;
    const std::shared_ptr<mblas::Tensor> B1_;
    const std::shared_ptr<mblas::Tensor> W2_;
    const std::shared_ptr<mblas::Tensor> B2_;
    const std::shared_ptr<mblas::Tensor> W3_;
    const std::shared_ptr<mblas::Tensor> B3_;
    const std::shared_ptr<mblas::Tensor> W4_;
    const std::shared_ptr<mblas::Tensor> B4_;
    const std::shared_ptr<mblas::Tensor> Gamma_0_;
    const std::shared_ptr<mblas::Tensor> Gamma_1_;
    const std::shared_ptr<mblas::Tensor> Gamma_2_;
  };

  ////////////////////////////////////////////////////////////////////////////////////////////////////

  // Loads all weights from the .npz file at `npzFile`.
  Weights(const std::string& npzFile, const YAML::Node& config,  unsigned device);

  // Loads all weights from an already-opened NpzConverter.
  Weights(const NpzConverter& model, const YAML::Node& config, unsigned device);

  Weights(const Weights&) = delete;

  // Returns the stored device index (presumably the `device` constructor
  // argument -- constructors are defined out of line).
  unsigned GetDevice() const {
    return device_;
  }

private:
  // Populate the cell-type-specific pointers below according to `celltype`.
  void initEncForward(const NpzConverter& model,std::string celltype);
  void initEncBackward(const NpzConverter& model,std::string celltype);
  void initDec1(const NpzConverter& model,std::string celltype);
  void initDec2(const NpzConverter& model,std::string celltype);

public:
  const EncEmbeddings encEmbeddings_;
  const DecEmbeddings decEmbeddings_;
  // Of these, usually only two at a time will be non-null; which ones
  // presumably depends on the celltype passed to initEncForward/initEncBackward.
  std::shared_ptr<EncForwardGRU> encForwardGRU_;
  std::shared_ptr<EncBackwardGRU> encBackwardGRU_;
  std::shared_ptr<EncForwardLSTM> encForwardLSTM_;
  std::shared_ptr<EncBackwardLSTM> encBackwardLSTM_;
  std::shared_ptr<MultWeights<EncForwardLSTM>> encForwardMLSTM_;
  std::shared_ptr<MultWeights<EncBackwardLSTM>> encBackwardMLSTM_;
  const DecInit decInit_;
  // Of these, usually only two at a time will be non-null; which ones
  // presumably depends on the celltype passed to initDec1/initDec2.
  std::shared_ptr<DecGRU1> decGru1_;
  std::shared_ptr<DecGRU2> decGru2_;
  std::shared_ptr<DecLSTM1> decLSTM1_;
  std::shared_ptr<DecLSTM2> decLSTM2_;
  std::shared_ptr<MultWeights<DecLSTM1>> decMLSTM1_;
  std::shared_ptr<MultWeights<DecLSTM2>> decMLSTM2_;
  const DecAlignment decAlignment_;
  const DecSoftmax decSoftmax_;

  // Device index returned by GetDevice().
  const unsigned device_;
};

}
}