Welcome to mirror list, hosted at ThFree Co, Russian Federation.

tensor_operators.h « src - github.com/marian-nmt/marian.git - Unnamed repository; edit this file 'description' to name the repository.
summaryrefslogtreecommitdiff
blob: 7ec4ca681c250a2e6d7c114dc66e8f6c0855b8c3 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
#pragma once

#include "tensor.h"

namespace marian {

// Brings the thrust placeholder objects (_1, _2, ...) into namespace marian;
// presumably callers compose the element-wise functors passed to Element()
// from these placeholders — confirm against call sites.
using namespace thrust::placeholders;
// Launch-configuration caps used by the Element() host wrappers in this header:
// maximum threads per (1D) block and maximum number of (1D) blocks in the grid.
// NOTE(review): 65535 presumably mirrors the historical gridDim.x limit of
// pre-SM30 devices. As macros, both names leak into every includer.
#define MAX_THREADS 512
#define MAX_BLOCKS 65535

// Applies `functor` in place to every element of a rows x cols buffer:
// out[j * cols + i] = functor(out[j * cols + i]).
//
// Layout: `out` is addressed as a contiguous row-major matrix (row j starts
// at out + j * cols) and must be device-accessible.
// Launch: 1D grid and 1D block of any size. Block b processes rows
// b, b + gridDim.x, b + 2*gridDim.x, ...; inside a row, thread t processes
// columns t, t + blockDim.x, ... Each element is touched by exactly one
// thread, so no shared memory or barriers are involved.
// Loop indices are size_t to match `rows`/`cols` and to avoid int overflow
// (and signed/unsigned comparisons) on very large tensors.
template <class Functor>
__global__ void gElement(Functor functor, Float* out,
                         size_t rows, size_t cols) {
  for(size_t j = blockIdx.x; j < rows; j += gridDim.x) {
    Float* rowOut = out + j * cols;
    for(size_t i = threadIdx.x; i < cols; i += blockDim.x)
      rowOut[i] = functor(rowOut[i]);
  }
}

// Element-wise binary update of a rows x cols buffer:
// out[k] = functor(out[k], in[k]) for every k = j * cols + i.
//
// Layout: `out` and `in` are both addressed as contiguous row-major matrices
// with the same `cols` stride and must be device-accessible.
// Launch: 1D grid and 1D block of any size. Block b processes rows
// b, b + gridDim.x, ...; inside a row, thread t processes columns
// t, t + blockDim.x, ... No shared memory or barriers are involved.
// The pointers are not __restrict__-qualified; NOTE(review): callers may pass
// the same buffer as `out` and `in` — confirm before adding __restrict__.
// Loop indices are size_t to match `rows`/`cols` and avoid int overflow.
template <class Functor>
__global__ void gElement(Functor functor,
                         Float* out, const Float* in,
                         size_t rows, size_t cols) {
  for(size_t j = blockIdx.x; j < rows; j += gridDim.x) {
    Float* rowOut = out + j * cols;
    const Float* rowIn = in + j * cols;
    for(size_t i = threadIdx.x; i < cols; i += blockDim.x)
      rowOut[i] = functor(rowOut[i], rowIn[i]);
  }
}

// Element-wise ternary update of a rows x cols buffer:
// out[k] = functor(out[k], in1[k], in2[k]) for every k = j * cols + i.
//
// Layout: `out`, `in1` and `in2` are all addressed as contiguous row-major
// matrices with the same `cols` stride and must be device-accessible.
// Launch: 1D grid and 1D block of any size. Block b processes rows
// b, b + gridDim.x, ...; inside a row, thread t processes columns
// t, t + blockDim.x, ... No shared memory or barriers are involved.
// The pointers are not __restrict__-qualified; NOTE(review): inputs may alias
// `out` at call sites — confirm before adding __restrict__.
// Loop indices are size_t to match `rows`/`cols` and avoid int overflow.
template <class Functor>
__global__ void gElement(Functor functor,
                         Float* out, const Float* in1, const Float* in2,
                         size_t rows, size_t cols) {
  for(size_t j = blockIdx.x; j < rows; j += gridDim.x) {
    Float* rowOut = out + j * cols;
    const Float* rowIn1 = in1 + j * cols;
    const Float* rowIn2 = in2 + j * cols;
    for(size_t i = threadIdx.x; i < cols; i += blockDim.x)
      rowOut[i] = functor(rowOut[i], rowIn1[i], rowIn2[i]);
  }
}

// Element-wise quaternary update of a rows x cols buffer:
// out[k] = functor(out[k], in1[k], in2[k], in3[k]) for every k = j * cols + i.
//
// Layout: `out`, `in1`, `in2` and `in3` are all addressed as contiguous
// row-major matrices with the same `cols` stride and must be device-accessible.
// Launch: 1D grid and 1D block of any size. Block b processes rows
// b, b + gridDim.x, ...; inside a row, thread t processes columns
// t, t + blockDim.x, ... No shared memory or barriers are involved.
// The pointers are not __restrict__-qualified; NOTE(review): inputs may alias
// `out` at call sites — confirm before adding __restrict__.
// Loop indices are size_t to match `rows`/`cols` and avoid int overflow.
template <class Functor>
__global__ void gElement(Functor functor,
                         Float* out, const Float* in1,
                         const Float* in2, const Float* in3,
                         size_t rows, size_t cols) {
  for(size_t j = blockIdx.x; j < rows; j += gridDim.x) {
    Float* rowOut = out + j * cols;
    const Float* rowIn1 = in1 + j * cols;
    const Float* rowIn2 = in2 + j * cols;
    const Float* rowIn3 = in3 + j * cols;
    for(size_t i = threadIdx.x; i < cols; i += blockDim.x)
      rowOut[i] = functor(rowOut[i], rowIn1[i], rowIn2[i], rowIn3[i]);
  }
}

// @TODO add broadcasting

// Host wrapper: applies `functor` element-wise in place to Out, i.e.
// Out[k] = functor(Out[k]) for every element, via the gElement kernel.
//
// Only Out.shape()[0] (rows) and Out.shape()[1] (cols) are used; Out is
// presumably 2-D — confirm. The launch uses one block per row (capped at
// MAX_BLOCKS) and one thread per column (capped at MAX_THREADS); the kernel's
// stride loops cover whatever exceeds the caps. The kernel runs on the default
// stream and this call blocks the host until that stream has drained.
//
// Empty tensors (0 rows or 0 cols) return immediately. A zero-sized grid or
// block is an invalid launch configuration and would leave a pending CUDA
// error for an unrelated later cudaGetLastError() caller.
// NOTE(review): launch/synchronization errors are still not checked here.
template <class Functor>
void Element(Functor functor, Tensor Out) {
  const size_t rows = Out.shape()[0];
  const size_t cols = Out.shape()[1];
  if(rows == 0 || cols == 0)
    return;

  Float* d_out = Out.data();
  // Clamp in size_t before narrowing so huge dimensions cannot wrap negative.
  const int blocks  = (int)std::min<size_t>(MAX_BLOCKS, rows);
  const int threads = (int)std::min<size_t>(MAX_THREADS, cols);
  gElement<<<blocks, threads>>>(functor, d_out, rows, cols);
  cudaStreamSynchronize(0);
}

// Host wrapper: Out[k] = functor(Out[k], In[k]) for every element, via the
// binary gElement kernel.
//
// Only Out's shape drives the launch: rows = Out.shape()[0], cols =
// Out.shape()[1]. In is read with the same row-major layout, so it is assumed
// to have Out's shape (no broadcasting); NOTE(review): this is not checked.
// Launch: one block per row (capped at MAX_BLOCKS), one thread per column
// (capped at MAX_THREADS). The kernel runs on the default stream and this call
// blocks the host until that stream has drained.
//
// Empty tensors (0 rows or 0 cols) return immediately. A zero-sized grid or
// block is an invalid launch configuration and would leave a pending CUDA
// error for an unrelated later cudaGetLastError() caller.
// NOTE(review): launch/synchronization errors are still not checked here.
template <class Functor>
void Element(Functor functor,
             Tensor Out, const Tensor In) {
  const size_t rows = Out.shape()[0];
  const size_t cols = Out.shape()[1];
  if(rows == 0 || cols == 0)
    return;

  Float* d_out = Out.data();
  const Float* d_in = In.data();

  // Clamp in size_t before narrowing so huge dimensions cannot wrap negative.
  const int blocks  = (int)std::min<size_t>(MAX_BLOCKS, rows);
  const int threads = (int)std::min<size_t>(MAX_THREADS, cols);
  gElement<<<blocks, threads>>>(functor, d_out, d_in, rows, cols);
  cudaStreamSynchronize(0);
}

// Host wrapper: Out[k] = functor(Out[k], In1[k], In2[k]) for every element,
// via the ternary gElement kernel.
//
// Only Out's shape drives the launch: rows = Out.shape()[0], cols =
// Out.shape()[1]. In1 and In2 are read with the same row-major layout, so they
// are assumed to have Out's shape (no broadcasting); NOTE(review): not checked.
// Launch: one block per row (capped at MAX_BLOCKS), one thread per column
// (capped at MAX_THREADS). The kernel runs on the default stream and this call
// blocks the host until that stream has drained.
//
// Empty tensors (0 rows or 0 cols) return immediately. A zero-sized grid or
// block is an invalid launch configuration and would leave a pending CUDA
// error for an unrelated later cudaGetLastError() caller.
// NOTE(review): launch/synchronization errors are still not checked here.
template <class Functor>
void Element(Functor functor,
             Tensor Out, const Tensor In1, const Tensor In2) {
  const size_t rows = Out.shape()[0];
  const size_t cols = Out.shape()[1];
  if(rows == 0 || cols == 0)
    return;

  Float* d_out = Out.data();
  const Float* d_in1 = In1.data();
  const Float* d_in2 = In2.data();

  // Clamp in size_t before narrowing so huge dimensions cannot wrap negative.
  const int blocks  = (int)std::min<size_t>(MAX_BLOCKS, rows);
  const int threads = (int)std::min<size_t>(MAX_THREADS, cols);
  gElement<<<blocks, threads>>>(functor, d_out, d_in1, d_in2, rows, cols);
  cudaStreamSynchronize(0);
}

// Host wrapper: Out[k] = functor(Out[k], In1[k], In2[k], In3[k]) for every
// element, via the quaternary gElement kernel.
//
// Only Out's shape drives the launch: rows = Out.shape()[0], cols =
// Out.shape()[1]. In1..In3 are read with the same row-major layout, so they are
// assumed to have Out's shape (no broadcasting); NOTE(review): not checked.
// Launch: one block per row (capped at MAX_BLOCKS), one thread per column
// (capped at MAX_THREADS). The kernel runs on the default stream and this call
// blocks the host until that stream has drained.
//
// Empty tensors (0 rows or 0 cols) return immediately. A zero-sized grid or
// block is an invalid launch configuration and would leave a pending CUDA
// error for an unrelated later cudaGetLastError() caller.
// NOTE(review): launch/synchronization errors are still not checked here.
template <class Functor>
void Element(Functor functor,
             Tensor Out, const Tensor In1,
             const Tensor In2, const Tensor In3) {
  const size_t rows = Out.shape()[0];
  const size_t cols = Out.shape()[1];
  if(rows == 0 || cols == 0)
    return;

  Float* d_out = Out.data();
  const Float* d_in1 = In1.data();
  const Float* d_in2 = In2.data();
  const Float* d_in3 = In3.data();

  // Clamp in size_t before narrowing so huge dimensions cannot wrap negative.
  const int blocks  = (int)std::min<size_t>(MAX_BLOCKS, rows);
  const int threads = (int)std::min<size_t>(MAX_THREADS, cols);
  gElement<<<blocks, threads>>>(functor, d_out, d_in1, d_in2, d_in3,
                                rows, cols);
  cudaStreamSynchronize(0);
}

// Matrix product of A and B written into C; declared here, defined elsewhere.
// NOTE(review): presumably computes C = op(A) * op(B) + beta * C through the
// given cuBLAS handle, where op() transposes its argument when transA/transB
// is true, and returns C — confirm against the definition.
Tensor Prod(cublasHandle_t handle, Tensor C, const Tensor A, const Tensor B,
             bool transA, bool transB, Float beta);

// Same product without an explicit cuBLAS handle; beta defaults to 0.
// NOTE(review): presumably uses a library-managed default handle, and with
// beta == 0 overwrites C rather than accumulating into it — confirm.
Tensor Prod(Tensor C, const Tensor A, const Tensor B,
             bool transA, bool transB, Float beta = 0);

}