Welcome to mirror list, hosted at ThFree Co, Russian Federation.

tensor_operators.h « tensors « src - github.com/marian-nmt/marian.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
blob: dc29bf356f411a6853b5495e087fcd2646605f72 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
#pragma once

#include "common/definitions.h"
#include "tensors/allocator.h"
#include "tensors/tensor.h"

#include "tensors/dispatch.h"

#include "functional/shape.h"
#include "functional/tensor.h"
#include "functional/tmp.h"

#ifdef CUDA_FOUND
#include "tensors/gpu/add.h"
#include "tensors/gpu/algorithm.h"
#include "tensors/gpu/element.h"
#include "tensors/gpu/prod.h"
#endif

#include "tensors/cpu/add.h"
#include "tensors/cpu/element.h"

#include <algorithm>

namespace marian {

// Copies the element range [beg, end) to the output iterator `it`.
// In CUDA builds the device type of `backend` selects the implementation: GPU backends go
// through gpu::copy, all others use std::copy. In CPU-only builds `backend` is not needed
// and std::copy is always used.
template <typename InIt, typename OutIt>
void copy(Ptr<Backend>& backend, const InIt beg, const InIt end, OutIt it) {
#ifdef CUDA_FOUND
  if(backend->getDeviceId().type == DeviceType::gpu)
    gpu::copy(backend, beg, end, it);
  else
    std::copy(beg, end, it);
#else
  // Explicitly mark the parameter as intentionally unused. A bare `backend;` statement is a
  // discarded-value expression that compilers may flag as having no effect (-Wunused-value).
  (void)backend;
  std::copy(beg, end, it);
#endif
}

// Device-dispatched entry points generated by the DISPATCHn macros (presumably defined in
// tensors/dispatch.h, included above). The digit n equals the number of argument types
// listed after the function name.
// NOTE(review): by analogy with the hand-written wrappers in this file, each invocation
// presumably defines marian::Name that forwards to cpu::Name or gpu::Name according to the
// backend device — confirm in dispatch.h.
// CopyCast / AddCast: a non-const tensor followed by a const one (presumably destination,
// then source — confirm).
// IsNaN: the two non-const bool& parameters are presumably output flags set by the backend.
DISPATCH2(CopyCast, marian::Tensor, const marian::Tensor);
DISPATCH2(AddCast, marian::Tensor, const marian::Tensor);
DISPATCH4(IsNaN, const Tensor, Ptr<Allocator>, bool&, bool&);

// Backend implementations of SanitizeGradient. They are only declared in this header; the
// static inline SanitizeGradient wrapper below picks one based on the tensor's device. The
// GPU declaration exists only in CUDA builds.
#ifdef CUDA_FOUND
namespace gpu {
bool SanitizeGradient(marian::Tensor in, Ptr<Allocator> allocator, bool pruneNaN, bool clipInf);
}  // namespace gpu
#endif

namespace cpu {
bool SanitizeGradient(marian::Tensor in, Ptr<Allocator> allocator, bool pruneNaN, bool clipInf);
}  // namespace cpu

// Sanitizes the gradient tensor `in` using the backend that owns it.
// `allocator`, `pruneNaN` and `clipInf` are forwarded unchanged. The allocator is presumably
// used for temporary buffers — confirm in the backends.
// Returns whatever the backend returns. NOTE(review): the meaning of the bool is defined by
// cpu::/gpu::SanitizeGradient and is not visible from this header.
static inline bool SanitizeGradient(marian::Tensor in, Ptr<Allocator> allocator, bool pruneNaN, bool clipInf) {
#ifdef CUDA_FOUND
  const bool onGpu = in->getBackend()->getDeviceId().type == DeviceType::gpu;
  if(onGpu)
    return gpu::SanitizeGradient(in, allocator, pruneNaN, clipInf);
#endif
  // CPU path: the only option in non-CUDA builds and the fallback otherwise.
  return cpu::SanitizeGradient(in, allocator, pruneNaN, clipInf);
}

// Evaluates the element-wise expression `functor` over `out` and the optional input
// `tensors` (e.g. `_1 = ...` assigns into `out`, as Bernoulli below does).
// Only the device of `out` decides the backend: GPU kernel in CUDA builds when `out` lives
// on a GPU, CPU implementation otherwise.
template <class Functor, class... Tensors>
void Element(Functor functor, marian::Tensor out, Tensors... tensors) {
#ifdef CUDA_FOUND
  if(out->getBackend()->getDeviceId().type == DeviceType::gpu) {
    gpu::Element(functor, out, tensors...);
    return;
  }
#endif
  cpu::Element(functor, out, tensors...);
}

// Accumulates the results of `functor(tensors...)`, scaled by `scale`, into `out`.
// Reduce() zeroes `out` before calling this, which indicates that Add adds to the existing
// contents of `out` rather than overwriting them.
// GPU-resident outputs (CUDA builds only) use gpu::Add. On the CPU, the same operation is
// expressed as a sum-aggregation: start from 0 and combine values with `_1 + _2`.
template <class Functor, class... Tensors>
void Add(Functor functor, float scale, marian::Tensor out, Tensors... tensors) {
#ifdef CUDA_FOUND
  if(out->getBackend()->getDeviceId().type == DeviceType::gpu) {
    gpu::Add(functor, scale, out, tensors...);
    return;
  }
#endif
  cpu::Aggregate(functor,
                 /*aggInit=*/0.0f,
                 functional::_1 + functional::_2,
                 scale,
                 out,
                 tensors...);
}

// Convenience overload of Add(functor, scale, out, tensors...) with the scale fixed to 1.
template <class Functor, class... Tensors>
void Add(Functor functor, marian::Tensor out, Tensors... tensors) {
  constexpr float kUnitScale = 1.f;
  Add(functor, kUnitScale, out, tensors...);
}

// Folds the results of `functor(tensors...)` into `out` using the binary `aggFunctor`,
// starting from `aggInit`. Both backends receive a fixed scale of 1.0f.
// GPU-resident outputs (CUDA builds only) go to gpu::Aggregate; everything else goes to
// cpu::Aggregate.
template <class Functor, class AggFunctor, class... Tensors>
void Aggregate(Functor functor, float aggInit, AggFunctor aggFunctor, marian::Tensor out, Tensors... tensors) {
  const float unitScale = 1.0f;
#ifdef CUDA_FOUND
  if(out->getBackend()->getDeviceId().type == DeviceType::gpu) {
    gpu::Aggregate(functor, aggInit, aggFunctor, unitScale, out, tensors...);
    return;
  }
#endif
  cpu::Aggregate(functor, aggInit, aggFunctor, unitScale, out, tensors...);
}

// Like Add(functor, scale, out, tensors...), but overwrites `out` instead of accumulating
// into it: the output is reset to 0 before the scaled results are added.
template <class Functor, class... Tensors>
void Reduce(Functor functor, float scale, marian::Tensor out, Tensors... tensors) {
  const float zero = 0.f;
  out->set(zero);                        // start from a clean accumulator
  Add(functor, scale, out, tensors...);  // then add the scaled functor results
}

// Unscaled variant: clears `out` to 0, then adds functor(tensors...) into it via the
// scale-1 overload of Add().
template <class Functor, class... Tensors>
void Reduce(Functor functor, marian::Tensor out, Tensors... tensors) {
  // Reset first so the result does not include the previous contents of `out`.
  const float zero = 0.f;
  out->set(zero);
  Add(functor, out, tensors...);
}

// Reduction with a caller-supplied aggregation: fills `out` with the starting value
// `aggInit`, then lets Aggregate() fold functor(tensors...) into it with `aggFunctor`.
// Note the parameter order: Reduce takes (aggFunctor, aggInit), while Aggregate takes
// (aggInit, aggFunctor).
template <class Functor, class AggFunctor, class... Tensors>
void Reduce(Functor functor, AggFunctor aggFunctor, float aggInit, marian::Tensor out, Tensors... tensors) {
  out->set(aggInit);                                          // seed every output element
  Aggregate(functor, aggInit, aggFunctor, out, tensors...);  // fold the inputs in
}

// clang-format off
// Device-dispatched operators declared via DISPATCHn(Name, T1, ..., Tn), where n is the
// number of argument types. NOTE(review): each invocation presumably defines a marian::Name
// wrapper that forwards to cpu::Name or gpu::Name by backend device (see tensors/dispatch.h).

// Dense products. The bool/float arguments are backend parameters (presumably transpose
// flags and scaling factors — confirm in the prod implementations).
DISPATCH7(Prod, marian::Tensor, const marian::Tensor&, const marian::Tensor&, bool, bool, float, float)
DISPATCH8(Prod, marian::Tensor, const marian::Tensor&, const marian::Tensor&, bool, bool, float, float, Type) // separate overload (instead of a default argument) so that computeType can default to C->type(), which is difficult to express otherwise.

// Batched and sparse (CSR) products; these take an allocator for backend use.
DISPATCH8(ProdBatched, marian::Tensor, Ptr<Allocator>, const marian::Tensor, const marian::Tensor, bool, bool, float, float)
DISPATCH8(ProdBatchedLegacy, marian::Tensor, Ptr<Allocator>, const marian::Tensor, const marian::Tensor, bool, bool, float, float)
DISPATCH9(CSRProd, marian::Tensor, Ptr<Allocator>, const marian::Tensor&, const marian::Tensor&, const marian::Tensor&, const marian::Tensor&, bool, bool, float)

// Affine: product plus an additional tensor operand; the trailing bool's meaning is defined
// by the backends.
DISPATCH10(Affine, marian::Tensor, Ptr<Allocator>, const marian::Tensor&, const marian::Tensor&, const marian::Tensor&, bool, bool, float, float, bool)

// Softmax forward / gradient.
DISPATCH2(Softmax, marian::Tensor, marian::Tensor)
DISPATCH3(SoftmaxGrad, marian::Tensor, marian::Tensor, marian::Tensor)

// Log-softmax forward / gradient.
DISPATCH2(LogSoftmax, marian::Tensor, marian::Tensor)
DISPATCH3(LogSoftmaxGrad, marian::Tensor, marian::Tensor, marian::Tensor)

// Cross-entropy pick forward / backward; the float argument's meaning is defined by the backends.
DISPATCH4(CrossEntropyPick, marian::Tensor, marian::Tensor, marian::Tensor, float)
DISPATCH5(CrossEntropyPickBackward, marian::Tensor, marian::Tensor, marian::Tensor, marian::Tensor, float)

// N-dimensional transpose with an explicit axis permutation vector.
DISPATCH3(TransposeND, marian::Tensor, marian::Tensor, const std::vector<int>&)
DISPATCH3(TransposeNDGrad, marian::Tensor, marian::Tensor, const std::vector<int>&)

// Shift by a Shape-valued offset; the float (forward only) and bool flags are backend parameters.
DISPATCH5(Shift, marian::Tensor, marian::Tensor, marian::Shape, float, bool)
DISPATCH4(ShiftGrad, marian::Tensor, marian::Tensor, marian::Shape, bool)

// Concatenation of a list of tensors; the int is presumably the axis (cf. Deconcatenate's `ax`).
DISPATCH3(Concatenate, marian::Tensor, const std::vector<marian::Tensor>&, int)

// clang-format on

// Overwrites `resultTensor` in place with a random two-valued pattern. Each element becomes
// `scale + shift` when its uniform sample in [0, 1] falls below `keepProb`, and `shift`
// otherwise.
// Example: Bernoulli(tensor, 0.5f, 2.f, -1.f) yields roughly 50% ones and 50% minus ones.
static inline void Bernoulli(Tensor resultTensor, float keepProb, float scale = 1.f, float shift = 0.f) {
  using namespace functional;
  // Draw uniform samples straight into the output buffer, which doubles as scratch space.
  resultTensor->getBackend()->getRandomGenerator()->uniform(resultTensor, 0.f, 1.f);
  // Threshold against keepProb (giving 0 or 1), then apply the affine map scale * x + shift.
  Element(_1 = (_1 < keepProb) * scale + shift, resultTensor);
}

// Fills `tensor` in place with a dropout mask. Each element is 1/keepProb with probability
// keepProb = 1 - dropProb and 0 otherwise, so the mask's expected value is 1.
// If dropProb >= 1, every element is dropped and the mask is all zeros. Previously the
// scale 1/0 = inf made the dropped entries 0 * inf = NaN.
static inline void Dropout(Tensor tensor, float dropProb) {
  float keepProb = 1.f - dropProb;
  // Guard the degenerate case: with keepProb <= 0 no element is kept, so any finite scale
  // gives zeros, whereas 1/keepProb would be inf (0 * inf = NaN) or negative.
  float scale = keepProb > 0.f ? 1.f / keepProb : 0.f;
  Bernoulli(tensor, keepProb, scale, /*shift=*/0.f);
}

// Device-dispatched SinusoidalPositionEmbeddings(Tensor, int); the meaning of the int
// argument is defined by the cpu/gpu implementations (not visible here).
DISPATCH2(SinusoidalPositionEmbeddings, marian::Tensor, int);

// Backend implementations of Deconcatenate, only declared in this header and selected by
// the static inline Deconcatenate wrapper below. The GPU declaration exists only in CUDA builds.
#ifdef CUDA_FOUND
namespace gpu {
void Deconcatenate(std::vector<marian::Tensor>& outputs,
                   const marian::Tensor in,
                   int ax);
}  // namespace gpu
#endif

namespace cpu {
void Deconcatenate(std::vector<marian::Tensor>& outputs,
                   const marian::Tensor in,
                   int ax);
}  // namespace cpu

// Splits `in` along axis `ax` into the tensors held by `outputs` (presumably the inverse of
// Concatenate — confirm in the backends). The device of `in` selects the implementation:
// the GPU kernel in CUDA builds when `in` lives on a GPU, the CPU version otherwise.
static inline void Deconcatenate(std::vector<marian::Tensor>& outputs, const marian::Tensor in, int ax) {
#ifdef CUDA_FOUND
  const bool inputOnGpu = in->getBackend()->getDeviceId().type == DeviceType::gpu;
  if(inputOnGpu) {
    gpu::Deconcatenate(outputs, in, ax);
    return;
  }
#endif
  cpu::Deconcatenate(outputs, in, ax);
}

// clang-format off
// Device-dispatched forward layer normalization (four tensors plus a float). The float is
// presumably epsilon, matching `eps` in LayerNormalizationGrad below — confirm.
DISPATCH5(LayerNormalization, marian::Tensor, marian::Tensor, marian::Tensor, marian::Tensor, float)

// Backend implementations of LayerNormalizationGrad, only declared in this header and
// selected by the static inline wrapper below. Only the GPU variant takes an allocator;
// the CPU variant has the same parameters without it. The GPU declaration exists only in
// CUDA builds.
#ifdef CUDA_FOUND
namespace gpu {
void LayerNormalizationGrad(Ptr<Allocator> allocator,
                            Tensor gradX,
                            Tensor gradGamma,
                            Tensor gradBeta,
                            Tensor adj,
                            Tensor y,
                            Tensor x,
                            Tensor gamma,
                            Tensor beta,
                            float eps);
}  // namespace gpu
#endif

namespace cpu {
void LayerNormalizationGrad(Tensor gradX,
                            Tensor gradGamma,
                            Tensor gradBeta,
                            Tensor adj,
                            Tensor y,
                            Tensor x,
                            Tensor gamma,
                            Tensor beta,
                            float eps);
}  // namespace cpu

// Backward pass of layer normalization, dispatched on the device of `gradX`.
// The gradient tensors gradX/gradGamma/gradBeta are handed to the backend together with the
// incoming adjoint `adj`, the forward output `y`, the input `x`, the parameters gamma/beta
// and `eps`. `allocator` is used only on the GPU path; the CPU implementation takes none.
static inline void LayerNormalizationGrad(Ptr<Allocator> allocator,
                                          Tensor gradX,
                                          Tensor gradGamma,
                                          Tensor gradBeta,
                                          Tensor adj,
                                          Tensor y,
                                          Tensor x,
                                          Tensor gamma,
                                          Tensor beta,
                                          float eps) {
#ifdef CUDA_FOUND
  const bool onGpu = gradX->getBackend()->getDeviceId().type == DeviceType::gpu;
  if(onGpu) {
    gpu::LayerNormalizationGrad(allocator, gradX, gradGamma, gradBeta, adj, y, x, gamma, beta, eps);
    return;
  }
#endif
  cpu::LayerNormalizationGrad(gradX, gradGamma, gradBeta, adj, y, x, gamma, beta, eps);
}

// clang-format off
// NOTE(review): formatting is already disabled at this point (there is no matching
// `clang-format on` after the LayerNormalization section), so this directive is redundant
// but harmless.
// Device-dispatched forward RMS normalization; the float is presumably epsilon (cf. `eps` in
// RMSNormalizationGrad below).
DISPATCH5(RMSNormalization, marian::Tensor, marian::Tensor, marian::Tensor, marian::Tensor, float)

// Backend implementations of RMSNormalizationGrad, only declared in this header and selected
// by the static inline wrapper below. As with LayerNormalizationGrad, only the GPU variant
// takes an allocator. The GPU declaration exists only in CUDA builds.
#ifdef CUDA_FOUND
namespace gpu {
void RMSNormalizationGrad(Ptr<Allocator> allocator,
                          Tensor gradX,
                          Tensor gradGamma,
                          Tensor gradBeta,
                          Tensor adj,
                          Tensor y,
                          Tensor x,
                          Tensor gamma,
                          Tensor beta,
                          float eps);
}  // namespace gpu
#endif

namespace cpu {
void RMSNormalizationGrad(Tensor gradX,
                          Tensor gradGamma,
                          Tensor gradBeta,
                          Tensor adj,
                          Tensor y,
                          Tensor x,
                          Tensor gamma,
                          Tensor beta,
                          float eps);
}  // namespace cpu

// Backward pass of RMS normalization, dispatched on the device of `gradX`.
// Arguments mirror LayerNormalizationGrad: gradient outputs gradX/gradGamma/gradBeta, adjoint
// `adj`, forward output `y`, input `x`, parameters gamma/beta and `eps`.
// `allocator` is forwarded only to the GPU implementation.
static inline void RMSNormalizationGrad(Ptr<Allocator> allocator,
                                        Tensor gradX,
                                        Tensor gradGamma,
                                        Tensor gradBeta,
                                        Tensor adj,
                                        Tensor y,
                                        Tensor x,
                                        Tensor gamma,
                                        Tensor beta,
                                        float eps) {
#ifdef CUDA_FOUND
  if(gradX->getBackend()->getDeviceId().type == DeviceType::gpu) {
    gpu::RMSNormalizationGrad(allocator, gradX, gradGamma, gradBeta, adj, y, x, gamma, beta, eps);
    return;
  }
#endif
  // CPU path (no allocator parameter on this backend).
  cpu::RMSNormalizationGrad(gradX, gradGamma, gradBeta, adj, y, x, gamma, beta, eps);
}

// Device-dispatched operators (DISPATCHn: n = number of argument types).

// Highway layer forward / backward.
DISPATCH4(HighwayForward, marian::Tensor, const marian::Tensor, const marian::Tensor, const marian::Tensor)
DISPATCH7(HighwayBackward, marian::Tensor, marian::Tensor, marian::Tensor, const marian::Tensor, const marian::Tensor, const marian::Tensor, const marian::Tensor)

// Row gather / scatter-style pair; the third tensor presumably holds the row indices — confirm.
DISPATCH3(CopyRows, marian::Tensor, const marian::Tensor, const marian::Tensor)
DISPATCH3(PasteRows, marian::Tensor, const marian::Tensor, const marian::Tensor)

// Column counterparts of CopyRows / PasteRows.
DISPATCH3(CopyCols, marian::Tensor, const marian::Tensor, const marian::Tensor)
DISPATCH3(PasteCols, marian::Tensor, const marian::Tensor, const marian::Tensor)

// Select / Insert along a dimension; the int is presumably the axis — confirm.
DISPATCH4(Select, marian::Tensor, const marian::Tensor, const marian::Tensor, int)
DISPATCH4(Insert, marian::Tensor, const marian::Tensor, const marian::Tensor, int)

// TopK: two output tensors (presumably values and indices), an allocator, the input, two
// ints and a bool whose meanings are defined by the backends.
DISPATCH7(TopK, marian::Tensor, marian::Tensor, Ptr<Allocator>, const marian::Tensor, int, int, bool);

// LSTM cell / output forward passes over a list of input tensors.
DISPATCH2(LSTMCellForward, marian::Tensor, std::vector<marian::Tensor>)
DISPATCH2(LSTMOutputForward, marian::Tensor, std::vector<marian::Tensor>);
// clang-format on

// Backend implementations of LSTMCellBackward, only declared in this header and selected by
// the static inline wrapper below. The GPU declaration exists only in CUDA builds.
#ifdef CUDA_FOUND
namespace gpu {
void LSTMCellBackward(std::vector<marian::Tensor> outputs,
                      std::vector<marian::Tensor> inputs,
                      marian::Tensor adj);
}  // namespace gpu
#endif

namespace cpu {
void LSTMCellBackward(std::vector<marian::Tensor> outputs,
                      std::vector<marian::Tensor> inputs,
                      marian::Tensor adj);
}  // namespace cpu

// Backward pass of the LSTM cell, dispatched on the device of the adjoint `adj`.
// NOTE(review): `outputs` presumably receives gradients w.r.t. `inputs` — confirm in the
// backends. Both vectors are taken and forwarded by value.
static inline void LSTMCellBackward(std::vector<marian::Tensor> outputs,
                                    std::vector<marian::Tensor> inputs,
                                    marian::Tensor adj) {
#ifdef CUDA_FOUND
  const bool adjOnGpu = adj->getBackend()->getDeviceId().type == DeviceType::gpu;
  if(adjOnGpu) {
    gpu::LSTMCellBackward(outputs, inputs, adj);
    return;
  }
#endif
  cpu::LSTMCellBackward(outputs, inputs, adj);
}

// Backend implementations of LSTMOutputBackward, only declared in this header and selected
// by the static inline wrapper below. The GPU declaration exists only in CUDA builds.
#ifdef CUDA_FOUND
namespace gpu {
void LSTMOutputBackward(std::vector<marian::Tensor> outputs,
                        std::vector<marian::Tensor> inputs,
                        marian::Tensor adj);
}  // namespace gpu
#endif

namespace cpu {
void LSTMOutputBackward(std::vector<marian::Tensor> outputs,
                        std::vector<marian::Tensor> inputs,
                        marian::Tensor adj);
}  // namespace cpu

// Backward pass of the LSTM output step, dispatched on the device of `adj`.
// Same calling convention as LSTMCellBackward: `outputs` and `inputs` are taken by value and
// passed straight through to the selected backend.
static inline void LSTMOutputBackward(std::vector<marian::Tensor> outputs,
                                      std::vector<marian::Tensor> inputs,
                                      marian::Tensor adj) {
#ifdef CUDA_FOUND
  if(adj->getBackend()->getDeviceId().type == DeviceType::gpu) {
    gpu::LSTMOutputBackward(outputs, inputs, adj);
    return;
  }
#endif
  // CPU path: the only option in non-CUDA builds.
  cpu::LSTMOutputBackward(outputs, inputs, adj);
}

// Device-dispatched GRUFast forward step; the bool presumably corresponds to the `final`
// flag of GRUFastBackward below — confirm.
DISPATCH3(GRUFastForward, marian::Tensor, std::vector<marian::Tensor>, bool)

// Backend implementations of GRUFastBackward, only declared in this header and selected by
// the static inline wrapper below. Both backends take an allocator here. The GPU declaration
// exists only in CUDA builds.
#ifdef CUDA_FOUND
namespace gpu {
void GRUFastBackward(Ptr<Allocator> allocator,
                     std::vector<marian::Tensor> outputs,
                     std::vector<marian::Tensor> inputs,
                     marian::Tensor adj,
                     bool final);
}  // namespace gpu
#endif

namespace cpu {
void GRUFastBackward(Ptr<Allocator> allocator,
                     std::vector<marian::Tensor> outputs,
                     std::vector<marian::Tensor> inputs,
                     marian::Tensor adj,
                     bool final);
}  // namespace cpu

// Backward pass of the GRUFast step, dispatched on the device of `adj`.
// Every argument, including `allocator` and `final` (default false), is forwarded unchanged.
// NOTE(review): `outputs` presumably receives gradients w.r.t. `inputs` — confirm in backends.
static inline void GRUFastBackward(Ptr<Allocator> allocator,
                                   std::vector<marian::Tensor> outputs,
                                   std::vector<marian::Tensor> inputs,
                                   marian::Tensor adj,
                                   bool final = false) {
#ifdef CUDA_FOUND
  const bool onGpu = adj->getBackend()->getDeviceId().type == DeviceType::gpu;
  if(onGpu) {
    gpu::GRUFastBackward(allocator, outputs, inputs, adj, final);
    return;
  }
#endif
  cpu::GRUFastBackward(allocator, outputs, inputs, adj, final);
}

// clang-format off
// Device-dispatched attention forward (Att) and backward (AttBack) kernels.
DISPATCH4(Att, marian::Tensor, marian::Tensor, marian::Tensor, marian::Tensor)
DISPATCH7(AttBack, marian::Tensor, marian::Tensor, marian::Tensor, marian::Tensor, marian::Tensor, marian::Tensor, marian::Tensor)
// clang-format on

// Backend implementations of L2Norm, only declared in this header and selected by the static
// inline L2Norm wrapper below. The GPU declaration exists only in CUDA builds.
#ifdef CUDA_FOUND
namespace gpu {
float L2Norm(marian::Tensor in, Ptr<Allocator> allocator);
}  // namespace gpu
#endif

namespace cpu {
float L2Norm(marian::Tensor in, Ptr<Allocator> allocator);
}  // namespace cpu

// Returns the L2 norm of `in` as computed by the backend that owns the tensor.
// `allocator` is forwarded unchanged (presumably for temporary reduction buffers — confirm).
static inline float L2Norm(marian::Tensor in, Ptr<Allocator> allocator) {
#ifdef CUDA_FOUND
  const bool isGpuTensor = in->getBackend()->getDeviceId().type == DeviceType::gpu;
  if(isGpuTensor) {
    return gpu::L2Norm(in, allocator);
  }
#endif
  return cpu::L2Norm(in, allocator);
}

// clang-format off
// Device-dispatched pooling with masking, forward and backward. The trailing int and bool are
// backend parameters (presumably a pooling width and a pooling-mode flag — confirm).
DISPATCH5(PoolingWithMaskingForward, marian::Tensor, marian::Tensor, marian::Tensor, int, bool)
DISPATCH6(PoolingWithMaskingBackward, marian::Tensor, marian::Tensor, marian::Tensor, marian::Tensor, int, bool)
// clang-format on
}  // namespace marian