Welcome to mirror list, hosted at ThFree Co, Russian Federation.

element.cu « gpu « tensors « src - github.com/marian-nmt/marian.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
blob: 8525b71b36478712654f3ab9ab2359b1f55b7000 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
#include "tensors/gpu/element.h"

#include "functional/array.h"
#include "functional/functional.h"
#include "functional/tensor.h"
#include "functional/tmp.h"

#include "tensors/gpu/cuda_helpers.h"

namespace marian {
namespace gpu {

// Element-wise kernel: for every element of tensors[0] (the output) evaluate
// `functor` via functional::apply() and store the result in that element.
//
// Template parameters:
//   K         - number of tensors in `tensors`; tensors[0] is written, all K
//               tensors are handed to functional::apply().
//   broadcast - compile-time switch. When false, every tensor is addressed
//               with the output's linear index. When true, the index used for
//               tensors[1..K-1] is recomputed from the output element's
//               dimensions with Shape::bindex().
//   Functor   - callable forwarded to functional::apply().
//   T         - element type of the tensors.
//
// Launch layout: only the .x components of blockDim/gridDim/blockIdx/threadIdx
// are read, i.e. a 1D grid of 1D blocks is expected. The kernel is a
// grid-stride loop, so any grid size covers all elements. No shared memory is
// used.
template <size_t K, bool broadcast, class Functor, typename T>
__global__ void gElement(
    Functor functor,
    functional::Array<functional::Tensor<T>, K> tensors) {
  int length = tensors[0].shape().elements();
  functional::Array<int, functional::Shape::size()> dims;
  functional::Array<int, K> indices;

  // The running position is kept in 64 bits: with a 32-bit counter, adding the
  // grid stride can wrap around to a negative value when `length` is close to
  // INT_MAX, which would pass the bound check and write out of bounds.
  const long long stride = static_cast<long long>(blockDim.x) * gridDim.x;
  for(long long pos = static_cast<long long>(blockDim.x) * blockIdx.x + threadIdx.x;
      pos < length;
      pos += stride) {
    // Safe narrowing: pos < length and length is an int.
    int index = static_cast<int>(pos);

    // Default: every tensor is read/written at the same linear index.
    indices.fill(index);

    if(broadcast) {
      // Decompose the output index into per-axis coordinates, then map those
      // coordinates to an index into each input tensor.
      tensors[0].shape().dims(index, dims);
      for(int i = 1; i < static_cast<int>(K); ++i)
        indices[i] = tensors[i].shape().bindex(dims);
    }

    tensors[0].data()[index] = functional::apply(functor, tensors, indices);
  }
}


// Launches the gElement kernel for element type T: applies `functor`
// element-wise and writes the result into `out`.
//
// Parameters:
//   functor - callable forwarded to the kernel.
//   out     - output tensor; its shape determines the number of elements
//             processed and its device id selects the CUDA device.
//   tensors - further tensors passed to the kernel after `out`.
//
// The launch uses min(MAX_THREADS, length) threads per block and
// ceil(length / threads) blocks, capped at MAX_BLOCKS (both constants are
// defined outside this file, presumably in tensors/gpu/cuda_helpers.h).
// If any of `tensors` has a shape different from `out`, the broadcasting
// variant of the kernel is selected.
//
// No stream argument is passed to the launch and the function does not
// synchronize.
// NOTE(review): neither the result of cudaSetDevice() nor launch errors are
// checked here.
template <typename T, class Functor, class... Tensors>
void ElementTyped(Functor functor, Tensor out, Tensors... tensors) {
  //matchOrAbort<T>(out->type()); // @TODO: figure out undefined reference

  cudaSetDevice(out->getDeviceId().no);

  int length = out->shape().elements();
  // Nothing to compute for an empty tensor. Without this guard `threads`
  // would be 0 and the block-count computation below would divide by zero.
  if(length <= 0)
    return;

  int threads = std::min(MAX_THREADS, length);
  // Ceiling division so that a partially filled last block is still launched.
  int blocks = std::min(MAX_BLOCKS, length / threads + (length % threads != 0));

  constexpr size_t K = sizeof...(tensors) + 1;
  functional::Array<functional::Tensor<T>, K> gTensors = {out, tensors...};

  // Broadcasting is required as soon as one input shape differs from the
  // output shape.
  bool broadcast = false;
  for(size_t i = 1; i < K && !broadcast; ++i)
    broadcast = gTensors[0].shape() != gTensors[i].shape();

  // `broadcast` is a template parameter of the kernel, hence the two launches.
  if(broadcast)
    gElement<K, true><<<blocks, threads>>>(functor, gTensors);
  else
    gElement<K, false><<<blocks, threads>>>(functor, gTensors);
}

// Type-dispatching entry point for element-wise operations on the GPU.
//
// Calls checkCommonType() on all tensors (presumably to verify that they share
// one element type - confirm against its definition), then forwards to
// ElementTyped<> instantiated for the element type reported by `out`:
//   Type::float32 -> float
//   Type::float16 -> __half (only when built with COMPILE_FP16, aborts otherwise)
//   Type::float64 -> double
// Any other type aborts.
template <class Functor, class... Tensors>
void Element(Functor functor, Tensor out, Tensors... tensors) {
  checkCommonType(out, tensors...);

  const auto elementType = out->type();

  if(elementType == Type::float32) {
    ElementTyped<float>(functor, out, tensors...);
    return;
  }

  if(elementType == Type::float16) {
#if COMPILE_FP16
    ElementTyped<__half>(functor, out, tensors...);
#else
    ABORT("FP16 not supported with chosen current hardware or CUDA version");
#endif
    return;
  }

  if(elementType == Type::float64) {
    ElementTyped<double>(functor, out, tensors...);
    return;
  }

  // None of the supported floating-point types matched.
  ABORT("Type {} not yet supported", elementType);
}

#include "tensors/gpu/element.inc"
}  // namespace gpu
}  // namespace marian