Welcome to mirror list, hosted at ThFree Co, Russian Federation.

expression_operators.cpp « graph « src - github.com/marian-nmt/marian.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
blob: 41c86f74ffbaade867fdfcbf59fc4f7279db7836 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
#include "graph/expression_operators.h"
#include "layers/constructors.h"

#include "graph/node_operators.h"
#include "graph/node_operators_binary.h"
#include "graph/node_operators_unary.h"

#include "graph/auto_tuner.h"
#include "tensors/cpu/int16.h"

namespace marian {

// Tags node `a` with a debug message (by calling a->debug(message)) and returns that same
// node, so debug() can be wrapped around any sub-expression without changing the graph.
Expr debug(Expr a, const std::string& message) {
  Expr tagged = a;
  tagged->debug(message);
  return tagged;
}

// Logistic function. Note: scipy name is expit()
Expr sigmoid(Expr a) { return Expression<SigmoidNodeOp>(a); }

// Rectified linear unit, built on ReLUNodeOp.
Expr relu(Expr a) { return Expression<ReLUNodeOp>(a); }

// Parametric ReLU with slope `alpha`, built on PReLUNodeOp (note: alpha is its first argument).
Expr prelu(Expr a, float alpha) { return Expression<PReLUNodeOp>(alpha, a); }

// Leaky ReLU: prelu() with the slope fixed at 0.01.
Expr leakyrelu(Expr a) { return prelu(a, 0.01f); }

// Clips `a` to the range [-c, c] via ClipNodeOp.
// A threshold of exactly 0 means "no clipping": the input node itself is returned and no
// new node is added to the graph.
Expr clip(Expr a, float c) {
  return c == 0 ? a : Expression<ClipNodeOp>(a, c);
}

// Unary operators, each a thin wrapper that creates the matching node type.

// Logarithm of `a` (LogNodeOp).
Expr log(Expr a) {
  return Expression<LogNodeOp>(a);
}

// Exponential of `a` (ExpNodeOp).
Expr exp(Expr a) {
  return Expression<ExpNodeOp>(a);
}

// Swish activation (SwishNodeOp).
Expr swish(Expr a) {
  return Expression<SwishNodeOp>(a);
}

// Negation of `a` (NegNodeOp).
Expr operator-(Expr a) {
  return Expression<NegNodeOp>(a);
}

// Softmax of `a` along `axis` (default -1, the last axis).
// SoftmaxNodeOp only works on the last axis, so any other axis is first swapped into the
// last position, normalized there, and then swapped back.
Expr softmax(Expr a, int axis /*=-1*/)
{
  if(axis == -1)
    return Expression<SoftmaxNodeOp>(a);

  // @TODO: move axis parameter down into the kernel
  Expr lastAxisFirst = swapAxes(a, axis, -1);
  Expr normalized = softmax(lastAxisFirst, /*axis=*/-1);
  return swapAxes(normalized, axis, -1);
}

// Masked softmax. `zeroOneMask` is expected to hold 1 for positions that take part and 0
// for positions to suppress. Suppressed positions get -99999999 added to their logits, so
// their softmax weight becomes negligible.
Expr softmax(Expr a, Expr zeroOneMask, int axis /*=-1*/) {
  auto penalty = (1 - zeroOneMask) * -99999999.f;  // 0 where mask == 1, -99999999 where mask == 0
  return softmax(a + penalty, axis);
}

// Log-softmax of `a` (LogSoftmaxNodeOp); unlike softmax() there is no axis parameter here.
Expr logsoftmax(Expr a) {
  return Expression<LogSoftmaxNodeOp>(a);
}

/*********************************************************/

// Binary operators between two expressions; each wraps exactly one node type.
Expr operator+(Expr a, Expr b) { return Expression<PlusNodeOp>(a, b); }
Expr operator-(Expr a, Expr b) { return Expression<MinusNodeOp>(a, b); }
Expr operator*(Expr a, Expr b) { return Expression<MultNodeOp>(a, b); }
Expr operator/(Expr a, Expr b) { return Expression<DivNodeOp>(a, b); }

// logaddexp (numpy naming: log(exp(a) + exp(b))), via LogAddExpNodeOp.
Expr logaddexp(Expr a, Expr b) { return Expression<LogAddExpNodeOp>(a, b); }

// Larger / smaller of the two operands, via MaximumNodeOp / MinimumNodeOp.
Expr maximum(Expr a, Expr b) { return Expression<MaximumNodeOp>(a, b); }
Expr minimum(Expr a, Expr b) { return Expression<MinimumNodeOp>(a, b); }

// Comparison operators, all built on CmpNodeOp(lhs, rhs, relation, negate).
// The pairs lt/ge, eq/ne and gt/le share `relation` (-1, 0, +1) and differ only in
// `negate`. That suggests `relation` picks <, == or > and `negate` inverts the
// result — confirm against CmpNodeOp.

// Wraps `value` as a constant with empty shape {} on the graph of `like`, using the value
// type of `like` so that both comparison operands agree in type.
static Expr cmpScalarOperand(Expr like, float value) {
  return like->graph()->constant({}, inits::from_value(value), like->value_type());
}

// expression vs. expression
Expr lt(Expr a, Expr b) { return Expression<CmpNodeOp>(a, b, -1, false); }
Expr eq(Expr a, Expr b) { return Expression<CmpNodeOp>(a, b,  0, false); }
Expr gt(Expr a, Expr b) { return Expression<CmpNodeOp>(a, b,  1, false); }
Expr ge(Expr a, Expr b) { return Expression<CmpNodeOp>(a, b, -1,  true); }
Expr ne(Expr a, Expr b) { return Expression<CmpNodeOp>(a, b,  0,  true); }
Expr le(Expr a, Expr b) { return Expression<CmpNodeOp>(a, b,  1,  true); }

// scalar vs. expression: the scalar becomes a constant on b's graph
Expr lt(float a, Expr b) { return Expression<CmpNodeOp>(cmpScalarOperand(b, a), b, -1, false); }
Expr eq(float a, Expr b) { return Expression<CmpNodeOp>(cmpScalarOperand(b, a), b,  0, false); }
Expr gt(float a, Expr b) { return Expression<CmpNodeOp>(cmpScalarOperand(b, a), b,  1, false); }
Expr ge(float a, Expr b) { return Expression<CmpNodeOp>(cmpScalarOperand(b, a), b, -1,  true); }
Expr ne(float a, Expr b) { return Expression<CmpNodeOp>(cmpScalarOperand(b, a), b,  0,  true); }
Expr le(float a, Expr b) { return Expression<CmpNodeOp>(cmpScalarOperand(b, a), b,  1,  true); }

// expression vs. scalar: the scalar becomes a constant on a's graph
Expr lt(Expr a, float b) { return Expression<CmpNodeOp>(a, cmpScalarOperand(a, b), -1, false); }
Expr eq(Expr a, float b) { return Expression<CmpNodeOp>(a, cmpScalarOperand(a, b),  0, false); }
Expr gt(Expr a, float b) { return Expression<CmpNodeOp>(a, cmpScalarOperand(a, b),  1, false); }
Expr ge(Expr a, float b) { return Expression<CmpNodeOp>(a, cmpScalarOperand(a, b), -1,  true); }
Expr ne(Expr a, float b) { return Expression<CmpNodeOp>(a, cmpScalarOperand(a, b),  0,  true); }
Expr le(Expr a, float b) { return Expression<CmpNodeOp>(a, cmpScalarOperand(a, b),  1,  true); }

/*********************************************************/

// Arithmetic between an expression and a float scalar.
// + and - map onto ScalarAddNodeOp; * and / map onto ScalarMultNodeOp.
Expr operator+(Expr a, float b) { return Expression<ScalarAddNodeOp>(a, b); }
Expr operator+(float a, Expr b) { return Expression<ScalarAddNodeOp>(b, a); }  // a + b == b + a
Expr operator-(Expr a, float b) { return Expression<ScalarAddNodeOp>(a, -b); } // a + (-b)
Expr operator-(float a, Expr b) { return Expression<ScalarAddNodeOp>(-b, a); } // (-b) + a; -b is a NegNodeOp

Expr operator*(float a, Expr b) { return Expression<ScalarMultNodeOp>(b, a); }
Expr operator*(Expr a, float b) { return Expression<ScalarMultNodeOp>(a, b); }

// Division by a scalar is implemented as multiplication by its reciprocal.
Expr operator/(Expr a, float b) { return Expression<ScalarMultNodeOp>(a, 1.f / b); }

// Divides a float scalar by an expression: a / b.
// The scalar is materialized as a constant with empty shape {} on b's graph. It is created
// with b's value type, so both division operands agree in type (the same convention the
// scalar comparison overloads use). Previously the constant used the graph's default type.
// TODO: efficient version of this without constant()
Expr operator/(float a, Expr b) {
  auto aExpr = b->graph()->constant({}, inits::from_value(a), b->value_type());
  return aExpr / b;
}

// Expr pow(float a, Expr b) {
//  return Expression<Scalar1PowNodeOp>(a, b);
//
//}
//
// Expr pow(Expr a, float b) {
//  return Expression<Scalar2PowNodeOp>(a, b);
//
//}
//
// Expr pow(Expr a, Expr b) {
//  return Expression<PowNodeOp>(a, b);
//}

/*********************************************************/

// Concatenates the expressions in `concats` along axis `ax` (ConcatenateNodeOp).
Expr concatenate(const std::vector<Expr>& concats, int ax) {
  return Expression<ConcatenateNodeOp>(concats, ax);
}

// Repeats `a` `repeats` times along axis `ax` by concatenating it with itself.
// A single repeat returns `a` unchanged without creating a node.
Expr repeat(Expr a, size_t repeats, int ax) {
  if(repeats != 1) {
    std::vector<Expr> parts(repeats, a);  // the same node `a`, listed `repeats` times
    return concatenate(parts, ax);
  }
  return a;
}

// Reinterprets `a` with a new `shape` (ReshapeNodeOp).
Expr reshape(Expr a, Shape shape) {
  return Expression<ReshapeNodeOp>(a, shape);
}

// Convenience wrappers around atleast_nd() for fixed minimum ranks.
Expr atleast_1d(Expr a) { return atleast_nd(a, 1); }
Expr atleast_2d(Expr a) { return atleast_nd(a, 2); }
Expr atleast_3d(Expr a) { return atleast_nd(a, 3); }
Expr atleast_4d(Expr a) { return atleast_nd(a, 4); }

// Ensures `a` has at least `dims` axes. If it already does, `a` is returned as is.
// Otherwise the existing axes are copied right-aligned into a shape of rank `dims`, and
// the new leading axes keep the value Shape::resize() gives them (presumably 1 — confirm).
Expr atleast_nd(Expr a, size_t dims) {
  const auto& inShape = a->shape();
  if(inShape.size() >= dims)
    return a;

  Shape padded;
  padded.resize(dims);
  const int rank = (int)inShape.size();
  for(int k = 1; k <= rank; ++k)
    padded.set(-k, inShape[-k]);  // negative index: count from the last axis

  return reshape(a, padded);
}

// Reshapes `a` into a single axis holding all of its elements.
Expr flatten(Expr a) {
  Shape flat = {a->shape().elements()};
  return Expression<ReshapeNodeOp>(a, flat);
}

// Reshapes `a` into a matrix: the last axis is kept and all other axes are merged into rows.
Expr flatten_2d(Expr a) {
  const auto& in = a->shape();
  const auto cols = in[-1];
  Shape matrix = {in.elements() / cols, cols};
  return Expression<ReshapeNodeOp>(a, matrix);
}

// Returns an identity reshape of `a` that is marked as not trainable.
// The reshape is a dummy node; its only purpose is to carry the setTrainable(false) flag
// without modifying `a` itself.
Expr stopGradient(Expr a) {
  Expr frozen = reshape(a, a->shape());
  frozen->setTrainable(false);
  return frozen;
}

// Creates a constant on a's graph with the same shape AND value type as `a`, filled by
// `init`. Previously the value type was not passed, so the constant fell back to the
// graph's default type and could mismatch `a` (e.g. on non-float32 graphs).
Expr constant_like(Expr a, const NodeInitializer& init) {
  return a->graph()->constant(a->shape(), init, a->value_type());
}

// gather() -- gather arbitrary elements along an axis; batched or non-batched
Expr gather(Expr a, Expr indices, int axis) {
  return Expression<GatherNodeOp>(a, indices, axis);
}

// index_select() -- gather arbitrary elements along an axis; unbatched (indices are specified as a 1D vector)
Expr index_select(Expr a, Expr indices, int axis) {
  ABORT_IF(indices->shape().size() != 1, "Indices must be a 1D tensor");
  const auto& aShape = a->shape();

  // Matrices have dedicated kernels for selecting rows (axis 0) or columns (axis 1 / -1).
  const bool isMatrix = aShape.size() == 2;
  if(isMatrix && axis == 0)
    return Expression<RowsNodeOp>(a, indices);
  if(isMatrix && (axis == -1 || axis == 1))
    return Expression<ColsNodeOp>(a, indices);

  // Any other case: lay the index vector out along `axis` and delegate to gather().
  Shape indexShape;
  indexShape.resize(aShape.size());
  indexShape.set(axis, indices->shape()[0]);
  Expr alignedIndices = reshape(indices, indexShape);
  return gather(a, alignedIndices, axis);
}

// Same as above, with the indices given as a host-side vector that is turned into an
// index node on a's graph first.
Expr index_select(Expr a, const std::vector<IndexType>& indices, int axis) {
  Expr indexNode = a->graph()->indices(indices);
  return index_select(a, indexNode, axis);
}

// Copies the elements selected by `slice` along `axis` into a new tensor via gather().
// slice() uses this whenever the selection is not memory-consecutive (stride != 1, or
// leading axes larger than 1). slice() passes in an already normalized `slice`
// (non-negative begin/end).
static Expr sliceCopy(Expr a, const Slice& slice, int axis) { // copy a Slice via gather()
  ABORT_IF(slice.stride < 0, "Negative strides are not supported yet");
  // A zero stride would divide by zero in the reserve() below and never advance the loop.
  ABORT_IF(slice.stride == 0, "Slice stride must not be zero");
  // begin > end selects nothing as well. Without catching it, the element count passed to
  // reserve() below would be negative and wrap around to a huge size_t.
  ABORT_IF(slice.begin >= slice.end, "Empty slices are not allowed"); // @TODO: Or are they?

  std::vector<IndexType> indices;
  indices.reserve((slice.end - slice.begin - 1) / slice.stride + 1);  // ceil((end - begin) / stride)
  for(int i = slice.begin; i < slice.end; i += slice.stride)
    indices.push_back(static_cast<IndexType>(i));
  return gather(a, a->graph()->indices(indices, a, axis), axis);
}

// Zero-copy view of `slice` along `axis`; only valid when the selected region is
// memory-consecutive.
static Expr sliceView(Expr a, const Slice& slice, int axis) {
  return Expression<SliceViewNodeOp>(a, slice, axis);
}

// slice() -- gather a slice along an axis (step size > 1 allowed)
// numpy __getslice__ semantics, but with an axis parameter. Negative axis and slice values
// are normalized first. A slice covering the whole axis returns `a` itself. Consecutive
// slices become views; everything else is copied with gather().
Expr slice(Expr a, Slice slice, int axis) {
  const auto& shape = a->shape();
  axis  = shape.axis(axis);          // normalize negative axis
  slice = shape.slice(slice, axis);  // normalize negative slice values

  const bool coversWholeAxis = slice.begin == 0 && slice.end == shape[axis] && slice.stride == 1;
  if(coversWholeAxis)
    return a;  // no-op

#if 1 // until strided views are supported, non-consecutive slices are implemented via gather()
  // The slice is consecutive in memory only with unit stride and all preceding axes of size 1.
  bool consecutive = slice.stride == 1;
  for(int i = 0; consecutive && i < axis; ++i)
    consecutive = shape[i] == 1;
  if(!consecutive)
    return sliceCopy(a, slice, axis);
#endif
  return sliceView(a, slice, axis);
}

// Reductions of `a` along axis `ax`. Each one is a ReduceNodeOp with a different opcode.
Expr sum(Expr a, int ax)  { return Expression<ReduceNodeOp>(a, ax, ReduceNodeOpCode::sum); }
Expr mean(Expr a, int ax) { return Expression<ReduceNodeOp>(a, ax, ReduceNodeOpCode::mean); }

// Standard deviation: root-mean-square of the input centered on its mean along `ax`.
Expr std(Expr a, int ax) {
  Expr centered = a - mean(a, ax);
  return Expression<ReduceNodeOp>(centered, ax, ReduceNodeOpCode::rms);
}

// Variance: mean of squares of the input centered on its mean along `ax`.
Expr var(Expr a, int ax) {
  Expr centered = a - mean(a, ax);
  return Expression<ReduceNodeOp>(centered, ax, ReduceNodeOpCode::meanSqr);
}

Expr max(Expr a, int ax)  { return Expression<ReduceNodeOp>(a, ax, ReduceNodeOpCode::max); }
Expr min(Expr a, int ax)  { return Expression<ReduceNodeOp>(a, ax, ReduceNodeOpCode::min); }
Expr prod(Expr a, int ax) { return Expression<ReduceNodeOp>(a, ax, ReduceNodeOpCode::prod); }

// log(sum(exp(a)))
Expr logsumexp(Expr a, int ax) { return Expression<ReduceNodeOp>(a, ax, ReduceNodeOpCode::logSumExp); }

// Scalar (inner) product of `a` and `b` along axis `ax` (ScalarProductNodeOp).
Expr scalar_product(Expr a, Expr b, int ax) {
  return Expression<ScalarProductNodeOp>(a, b, ax);
}

// Weighted mean of `in` along `ax`: sum(in * weights) / sum(weights).
Expr weighted_average(Expr in, Expr weights, int ax) {
  Expr numerator = scalar_product(in, weights, ax);
  Expr denominator = sum(weights, ax);
  return numerator / denominator;
}

// Matrix product of `a` and `b`, each optionally transposed, with `scale` forwarded to the
// kernel. On an optimized CPU graph the int16 kernels are used, with the backend's clip
// value forwarded to quantize(). Otherwise a DotNodeOp is built on clip()-ed inputs.
Expr dot(Expr a, Expr b, bool transA, bool transB, float scale) {
  auto graph = a->graph();
  auto device = graph->getDeviceId().type;
  float clipValue = graph->getBackend()->getClip();

  // Currently only true when command line options
  // --optimize --cpu-thread=N with N > 0 are set.
  const bool useInt16 = graph->isOptimized() && device == DeviceType::cpu;
  if(!useInt16)
    return Expression<DotNodeOp>(
        clip(a, clipValue), clip(b, clipValue), transA, transB, scale);

  // dotInt16 computes A * B.T, hence the transpose for B to get A * B
  // if transA = false and transB = false.
  return cpu::int16::dot(
      cpu::int16::quantize(transA ? transpose(a) : a, clipValue),
      cpu::int16::quantize(transB ? b : transpose(b), clipValue),
      scale);
}

// Batched matrix product (DotBatchedNodeOp). Unlike dot(), there is no clipping and no
// int16 path.
Expr bdot(Expr a, Expr b, bool transA, bool transB, float scale) {
  Expr product = Expression<DotBatchedNodeOp>(a, b, transA, transB, scale);
  return product;
}

// Affine transform of `a` by `b` plus `bias` (presumably scale * op(a) * op(b) + bias,
// where op() transposes when transA/transB is set — AffineNodeOp is defined elsewhere).
//
// On an optimized CPU graph, two implementations are registered with a per-thread
// AutoTuner: an int16 quantized product and an AffineNodeOp product. tuner->run() then
// chooses between them. On every other configuration an AffineNodeOp is built directly on
// clip()-ed inputs.
Expr affine(Expr a, Expr b, Expr bias, bool transA, bool transB, float scale) {
  auto device = a->graph()->getDeviceId().type;

  float clipValue = a->graph()->getBackend()->getClip();

  if(a->graph()->isOptimized() && device == DeviceType::cpu) {
    // autotune is hard-coded to true, so the int16-only `else` branch below is currently
    // unreachable.
    bool autotune = true;
    if(autotune) {
      // One tuner per thread; it persists across calls on that thread.
      thread_local Ptr<AutoTuner<Expr>> tuner = New<AutoTuner<Expr>>();

      // start with new set of algorithms
      tuner->clear();

      // lower precision for shapes, reduces data sparsity
      // (every dimension is integer-divided by 4, so nearby shapes map to the same context)
      auto sh = [](Shape sh) {
        for(size_t i = 0; i < sh.size(); ++i)
          sh.set(i, sh[i] / 4);
        return sh;
      };

      // create context for current call as hash
      std::size_t hash = sh(a->shape()).hash();
      util::hash_combine(hash, sh(b->shape()).hash());
      util::hash_combine(hash, sh(bias->shape()).hash());
      util::hash_combine(hash, transA);
      util::hash_combine(hash, transB);

      // add first algorithm variant (Int16)
      // hash1 = context hash combined with variant id 1.
      size_t hash1 = hash;
      util::hash_combine(hash1, 1);
      // rec1 registers node `e` with the tuner under hash1 and returns it unchanged;
      // stop = true is passed only for the final output node of the variant.
      auto rec1 = [=](Expr e, bool stop = false) {
        e->record(tuner, hash1, stop);
        return e;
      };
      // NOTE(review): only the a-side intermediates (transpose, quantize) go through rec1;
      // the b-side transpose/quantize nodes are not recorded — confirm this is intended.
      auto alg1 = [=]() {
        return rec1(
            cpu::int16::affine(
                rec1(cpu::int16::quantize(transA ? rec1(transpose(a)) : a,
                                          clipValue)),
                cpu::int16::quantize(transB ? b : transpose(b), clipValue),
                bias,
                scale),
            true);
      };
      tuner->insert({hash1, alg1});

      // add second algorithm variant (CBlas)
      // hash2 = context hash combined with variant id 2.
      size_t hash2 = hash;
      util::hash_combine(hash2, 2);
      auto rec2 = [=](Expr e, bool stop = false) {
        e->record(tuner, hash2, stop);
        return e;
      };

      auto alg2 = [=]() {
        // clip() returns its input unchanged when clipping is disabled, so a node is only
        // recorded when clip() actually created one.
        auto ac = clip(a, clipValue);
        if(ac != a)
          ac = rec2(ac);

        auto bc = clip(b, clipValue);
        if(bc != b)
          bc = rec2(bc);

        // Column of ones with one entry per row of `ac` viewed as a matrix (all leading
        // axes merged); passed to AffineNodeOp alongside the bias.
        int rows = ac->shape().elements() / ac->shape()[-1];
        Expr ones = ac->graph()->ones({rows, 1});
        std::vector<Expr> nodes = {ac, bc, bias, ones};
        return rec2(Expression<AffineNodeOp>(nodes, transA, transB, scale),
                    true);
      };
      tuner->insert({hash2, alg2});

      // execute algorithm with autotuning
      return tuner->run();

    } else {
      // cpu int16 version
      return cpu::int16::affine(
          cpu::int16::quantize(transA ? transpose(a) : a, clipValue),
          cpu::int16::quantize(transB ? b : transpose(b), clipValue),
          bias,
          scale);
    }
  } else {
    // general version, MKL, CBlas or CUDA

    // if clipValue > 0, the inputs will be clipped to range [-clipValue,
    // clipValue] This is meant to keep values at the same range as used during
    // training when optimizing for 8-bit integer products. Likely to be removed
    // in the future when we explore better ways to handle this.

    // NOTE(review): `rows` is derived from a's untransposed layout; confirm that
    // AffineNodeOp expects this row count when transA is true.
    int rows = a->shape().elements() / a->shape()[-1];
    Expr ones = a->graph()->ones({rows, 1});
    std::vector<Expr> nodes
        = {clip(a, clipValue), clip(b, clipValue), bias, ones};
    return Expression<AffineNodeOp>(nodes, transA, transB, scale);
  }
}

// Sparse x dense: multiplies the CSR matrix A (shape A_shape) by the dense matrix B.
// A[i,j] is at A_values[A_offsets[i]+k], where k is position of j in A_indices[A_offsets[i]:A_offsets[i+1]]
// @TODO: Define a proper sparse tensor type.
Expr csr_dot(const Shape& A_shape, Expr A_values, Expr A_indices, Expr A_offsets, Expr B, bool transA /*= false*/) {
  const bool swapOperands = false;  // sparse operand on the left
  return Expression<CSRDotNodeOp>(A_shape, A_values, A_indices, A_offsets, B, transA, swapOperands);
}

// Dense x sparse: multiplies the dense matrix A by the CSR matrix B (shape B_shape).
// Uses the same CSRDotNodeOp as csr_dot(), with the sparse operand passed first and the
// swap flag set.
// @TODO: Define a proper sparse tensor type.
Expr dot_csr(Expr A, const Shape& B_shape, Expr B_values, Expr B_indices, Expr B_offsets, bool transB /*= false*/) {
  const bool swapOperands = true;  // sparse operand on the right
  return Expression<CSRDotNodeOp>(B_shape, B_values, B_indices, B_offsets, A, transB, swapOperands);
}

// Swaps the last two axes of `a`. Tensors with fewer than two axes get the identity
// permutation (a TransposeNodeOp is still created).
// @TODO: change to swapAxes(a, -1, -2)
Expr transpose(Expr a) {
  const int rank = (int)a->shape().size();
  std::vector<int> perm;
  perm.reserve(rank);
  for(int d = 0; d < rank; ++d)
    perm.push_back(d);
  if(rank >= 2)
    std::swap(perm[rank - 2], perm[rank - 1]);
  return transpose(a, perm);
}

// Permutes the axes of `a` according to `axes` (TransposeNodeOp).
Expr transpose(Expr a, const std::vector<int>& axes) {
  return Expression<TransposeNodeOp>(a, axes);
}

// Exchanges axes `axis1` and `axis2` of `x` (negative values count from the end).
// If both refer to the same axis, `x` is returned unchanged.
Expr swapAxes(Expr x, int axis1, int axis2)
{
  const auto& shape = x->shape();
  const int first = shape.axis(axis1);
  const int second = shape.axis(axis2);
  if(first == second)
    return x;

  // TODO: This duplicates transpose(x). Implement transpose(x) as swapAxes(x, -1, -2)
  std::vector<int> perm(shape.size());
  for(size_t d = 0; d < perm.size(); ++d)
    perm[d] = (int)d;
  perm[first] = second;  // identity permutation with the two entries exchanged
  perm[second] = first;
  return transpose(x, perm);
}

// Cross-entropy between `a` and the entries selected by `indices` (CrossEntropyNodeOp).
// Presumably `a` holds logits and `indices` the target ids — confirm against the node.
Expr cross_entropy(Expr a, Expr indices) {
  Expr loss = Expression<CrossEntropyNodeOp>(a, indices);
  return loss;
}

// Multi-input variants of the activations. Only tanh() has an implementation
// (TanhNodeOp over all inputs); every other variant aborts with "Not implemented".
Expr plus(const std::vector<Expr>&) { ABORT("Not implemented"); }

Expr swish(const std::vector<Expr>&) { ABORT("Not implemented"); }

Expr tanh(const std::vector<Expr>& nodes) { return Expression<TanhNodeOp>(nodes); }

Expr sigmoid(const std::vector<Expr>&) { ABORT("Not implemented"); }

Expr relu(const std::vector<Expr>&) { ABORT("Not implemented"); }

Expr leakyrelu(const std::vector<Expr>&) { ABORT("Not implemented"); }

Expr prelu(const std::vector<Expr>&, float /*alpha*/) { ABORT("Not implemented"); }

// Square root of `a`. `eps` is forwarded to SqrtNodeOp (presumably a small offset for
// numerical stability — confirm in the node).
Expr sqrt(Expr a, float eps) { return Expression<SqrtNodeOp>(a, eps); }

// Square of `a` (SquareNodeOp).
Expr square(Expr a) { return Expression<SquareNodeOp>(a); }

// Layer normalization of `x` with scale `gamma` and optional shift `beta`.
// `eps` is forwarded to LayerNormalizationOp. When `beta` is null, only {x, gamma} are
// passed to the node.
Expr layerNorm(Expr x,
               Expr gamma,
               Expr beta /*= nullptr*/,
               float eps /*= 1e-9*/) {
  std::vector<Expr> inputs{x, gamma};
  if(beta)
    inputs.push_back(beta);
  return Expression<LayerNormalizationOp>(inputs, eps);
}

// Fused highway combination of `y`, `x` and gate `t` (HighwayNodeOp).
Expr highway(Expr y, Expr x, Expr t) {
  return Expression<HighwayNodeOp>(std::vector<Expr>{y, x, t});
}

// Highway layer on `x`. It builds two dense layers, both as wide as x's last axis:
//  - a sigmoid gate, with parameters named "<prefix>_highway_d1"
//  - a ReLU transform, with parameters named "<prefix>_highway_d2"
// The result is gate * transform + (1 - gate) * x.
Expr highway(const std::string prefix, Expr x) {
  // clang-format off
  auto graph = x->graph();
  size_t width = x->shape()[-1];
  auto gate = mlp::dense()
      ("prefix", prefix + "_highway_d1")
      ("dim", width)
      ("activation", mlp::act::sigmoid)
      .construct(graph)->apply(x);
  auto transformed = mlp::dense()
      ("prefix", prefix + "_highway_d2")
      ("dim", width)
      ("activation", mlp::act::ReLU)
      .construct(graph)->apply(x);
  return (gate * transformed) + ((1 - gate) * x);
  // clang-format on
}

// Expr batch_norm(Expr x, Expr gamma, Expr beta) {
//  auto mju = mean(x, keywords::axis=0);
//  auto xmmju = x - mju;
//  auto std = sqrt(mean(square(xmmju), keywords::axis=0), 1e-9);
//
//  if(beta)
//    return gamma * (xmmju / std) + beta;
//  else
//    return gamma * (xmmju / std);
//}

// Shifts `a` by the per-axis offsets in `shift`, with `padValue` forwarded to ShiftNodeOp
// (presumably used to fill vacated positions — confirm in the node).
Expr shift(Expr a, Shape shift, float padValue) {
  Expr shifted = Expression<ShiftNodeOp>(a, shift, padValue);
  return shifted;
}

// Expr lexical_bias(Expr logits, Expr att, float eps, Ptr<sparse::CSR> lf) {
//  return Expression<LexicalProbNodeOp>(logits, att, eps, lf);
//}

#ifdef CUDA_FOUND
#ifdef CUDNN

// 2D pooling over `x` with the given window, padding and stride. Both variants use the
// same PoolingOp and differ only in the mode string ("avg" / "max").
Expr avg_pooling(Expr x,
                 int height,
                 int width,
                 int padHeight,
                 int padWidth,
                 int strideHeight,
                 int strideWidth) {
  Expr pooled = Expression<PoolingOp>(x, height, width, padHeight, padWidth,
                                      strideHeight, strideWidth, "avg");
  return pooled;
}

Expr max_pooling(Expr x,
                 int height,
                 int width,
                 int padHeight,
                 int padWidth,
                 int strideHeight,
                 int strideWidth) {
  Expr pooled = Expression<PoolingOp>(x, height, width, padHeight, padWidth,
                                      strideHeight, strideWidth, "max");
  return pooled;
}

// Reorders `x` from [numWords, numExamples, embSize] to [numExamples, 1, numWords, embSize].
// It flattens to rows indexed (word * numExamples + example) and gathers them so that all
// rows of one example come out consecutively.
Expr convert2cudnnFormat(Expr x) {
  const int numWords = x->shape()[0];
  const int numExamples = x->shape()[1];
  const int embSize = x->shape()[2];

  std::vector<IndexType> order;
  order.reserve(numWords * numExamples);
  for(int b = 0; b < numExamples; ++b)
    for(int t = 0; t < numWords; ++t)
      order.push_back(t * numExamples + b);

  auto flat = reshape(x, {numWords * numExamples, embSize});

  Shape outShape({numExamples, 1, numWords, embSize});
  return reshape(rows(flat, order), outShape);
}

// Inverse reordering of convert2cudnnFormat(). It reads [batch, ?, sentence, embSize]
// (axis 1 is ignored) and produces [batch, sentence, embSize]. The rows are flattened
// batch-major, then gathered so that rows of the same sentence position across the batch
// are consecutive.
Expr convertFromcudnnFormat(Expr x) {
  const int batchDim = x->shape()[0];
  const int sentenceDim = x->shape()[2];
  const int embSize = x->shape()[3];

  auto flat = reshape(x, {batchDim * sentenceDim, embSize});

  std::vector<IndexType> order;
  order.reserve(sentenceDim * batchDim);
  for(int t = 0; t < sentenceDim; ++t)
    for(int b = 0; b < batchDim; ++b)
      order.push_back(b * sentenceDim + t);

  Shape outShape({batchDim, sentenceDim, embSize});
  return reshape(rows(flat, order), outShape);
}

// Pooling over `x` restricted by `mask`, with window `width` and parity flag `isEven`
// forwarded to PoolingWithMaskingOp.
Expr pooling_with_masking(Expr x, Expr mask, int width, bool isEven) {
  Expr pooled = Expression<PoolingWithMaskingOp>(x, mask, width, isEven);
  return pooled;
}

#endif
#endif
}  // namespace marian