github.com/marian-nmt/marian.git
Diffstat (limited to 'src/tensors/gpu/add.h')
-rw-r--r--  src/tensors/gpu/add.h  180
1 file changed, 180 insertions(+), 0 deletions(-)
diff --git a/src/tensors/gpu/add.h b/src/tensors/gpu/add.h
new file mode 100644
index 00000000..13ffc500
--- /dev/null
+++ b/src/tensors/gpu/add.h
@@ -0,0 +1,180 @@
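+// Element-wise "add into" kernels for the GPU backend: out += scale * f(ins...),
+// with support for broadcasting the inputs and for folding dimensions that are
+// collapsed in the output tensor.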
+#pragma once
+
+#include "gpu/shape.h"
+#include "gpu/tmp.h"
+#include "gpu/tensor.h"
+#include "functional/functional.h"
+
+#include <algorithm>  // std::min in the host-side launch configuration
+
+namespace marian {
+
+namespace gpu {
+
+#ifdef __CUDACC__
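+// gAddGeneric: grid-strided kernel for the general case. Each thread handles one
+// output element; when the output shape equals the full broadcast shape it simply
+// accumulates functor(ins...), otherwise gpu::loops folds the functor over the
+// dimensions in which out is smaller than the broadcast shape (ratio full[i] / out[i]).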
+template <size_t K, class Functor>
+__global__ void gAddGeneric(Functor functor,
+                            const gpu::Shape full,
+                            gpu::Tensor<float> out,
+                            gpu::Array<gpu::Tensor<float>, K> ins,
+                            float scale = 1.0) {
+
+  int outLength = out.shape().elements();
+  bool same = outLength == full.elements();
+  for(int i = 0; i < K; ++i)
+    same = same && outLength == ins[i].shape().elements();
+
+  constexpr size_t N = gpu::Shape::size();
+  gpu::Array<int, N> len;
+  for(int i = 0; i < N; ++i)
+    len[i] = full[i] / out.shape()[i];
+
+  gpu::Array<int, N> dims;
+  for(int bid = 0; bid < outLength; bid += blockDim.x * gridDim.x) {
+    int index = bid + blockDim.x * blockIdx.x + threadIdx.x;
+    if(index < outLength) {
+      if(same) {
+        out[index] += gpu::apply(functor, ins, index) * scale;
+      } else {
+        out.shape().dims(index, dims);
+        out[index] += gpu::loops(functor, ins, len, dims) * scale;
+      }
+    }
+  }
+}
+
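+// gAddEqual: the output already has the full broadcast shape. Each thread
+// accumulates functor(ins...) for one element; if `broadcast` is set, the input
+// indices are remapped through bindex so that smaller inputs are broadcast into out.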
+template <size_t K, class Functor>
+__global__ void gAddEqual(Functor functor,
+                          gpu::Tensor<float> out,
+                          gpu::Array<gpu::Tensor<float>, K> ins,
+                          float scale,
+                          bool broadcast) {
+  int length = out.shape().elements();
+  gpu::Array<int, gpu::Shape::size()> dims;
+
+  for(int bid = 0; bid < length; bid += blockDim.x * gridDim.x) {
+    int index = bid + blockDim.x * blockIdx.x + threadIdx.x;
+    if(index < length) {
+      gpu::Array<int, K> indices;
+      indices.fill(index);
+
+      if(broadcast) {
+        out.shape().dims(index, dims);
+        for(size_t i = 0; i < K; ++i)
+          indices[i] = ins[i].shape().bindex(dims);
+      }
+
+      out[index] += gpu::apply(functor, ins, indices) * scale;
+    }
+  }
+}
+
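+// gAddReduce: reduces over the last axis of the full broadcast shape. Each block
+// handles one "row" (all leading axes fixed), sums functor(ins...) over the columns
+// into shared memory, and accumulates the block-wide result into out[j].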
+template <size_t K, class Functor>
+__global__ void gAddReduce(Functor functor,
+                           const gpu::Shape full,
+                           gpu::Tensor<float> out,
+                           gpu::Array<gpu::Tensor<float>, K> ins,
+                           float scale = 1.0) {
+
+  int rows = full.elements() / full.back();
+  int cols = full.back();
+
+  bool same = true;
+  for(int i = 0; i < K; ++i)
+    same = same && ins[i].shape().elements() == full.elements();
+
+  for(int bid = 0; bid < rows; bid += gridDim.x) {
+    int j = bid + blockIdx.x;
+    if(j < rows) {
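+      // dynamic shared memory: the launch allocates 2 * blockDim.x floats;
+      // the upper half is used here for the per-thread partial sums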
+      extern __shared__ float _share[];
+      float* _sum = _share + blockDim.x;
+
+      if(same) {
+        _sum[threadIdx.x] = 0;
+        for(int tid = 0; tid < cols; tid += blockDim.x) {
+          int id = tid + threadIdx.x;
+          if(id < cols)
+            _sum[threadIdx.x] += gpu::apply(functor, ins, j * cols + id);
+        }
+      } else {
+        gpu::Array<int, gpu::Shape::size()> dims;
+        _sum[threadIdx.x] = 0;
+
+        for(int tid = 0; tid < cols; tid += blockDim.x) {
+          int id = tid + threadIdx.x;
+          if(id < cols) {
+            full.dims(j * cols + id, dims);
+            gpu::Array<int, K> indices;
+            for(int i = 0; i < K; ++i)
+              indices[i] = ins[i].shape().bindex(dims);
+            _sum[threadIdx.x] += gpu::apply(functor, ins, indices);
+          }
+        }
+      }
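+      // block-wide tree reduction of the per-thread partial sums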
+      __syncthreads();
+      int len = blockDim.x;
+      while(len != 1) {
+        __syncthreads();
+        int skip = (len + 1) >> 1;
+        if(threadIdx.x < (len >> 1)) {
+          _sum[threadIdx.x] += _sum[threadIdx.x + skip];
+        }
+        len = (len + 1) >> 1;
+      }
+      __syncthreads();
+      out[j] += _sum[0] * scale;
+    }
+  }
+}
+#endif
+
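+// Add: host-side entry point. Accumulates scale * functor(tensors...) into out,
+// folding over broadcast dimensions where needed, and dispatches to one of the
+// kernels above based on how the output shape relates to the broadcast shape of
+// all operands. Note: marian::Tensor, marian::Shape, MAX_THREADS, MAX_BLOCKS and
+// ABORT are assumed to be provided by headers included elsewhere in marian; they
+// are not declared in this file.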
+template <class Functor, class ...Tensors>
+void Add(Functor functor,
+         float scale,
+         marian::Tensor out,
+         Tensors... tensors) {
+
+#ifdef __CUDACC__
+  cudaSetDevice(out->getDevice().no);
+
+  auto full = marian::Shape::broadcast({out, tensors...});
+
+  int length = out->shape().elements();
+
+  constexpr size_t K = sizeof...(Tensors);
+
+  gpu::Tensor<float> gOut = out;
+  gpu::Array<gpu::Tensor<float>, K> gIns = {tensors ...};
+
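+  // Case 1: the last axis of the broadcast shape is collapsed in out, so use the
+  // shared-memory reduction kernel.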
+  if(full.back() != 1 && out->shape().back() == 1) {
+    size_t m = full.elements() / length;
+    size_t k = full.back();
+
+    int blocks = std::min(MAX_BLOCKS, (int)m);
+    int threads = std::min(MAX_THREADS, (int)k);
+    int shared = sizeof(float) * threads * 2;
+
+    gAddReduce<<<blocks, threads, shared>>>(functor, full, gOut, gIns, scale);
+
+  } else if(out->shape() == full) {
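+    // Case 2: out already has the full broadcast shape; only the inputs may need
+    // broadcasting.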
+    int threads = std::min(MAX_THREADS, length);
+    int blocks
+        = std::min(MAX_BLOCKS, length / threads + (length % threads != 0));
+
+    bool broadcast = false;
+    for(int i = 0; i < K; ++i)
+      broadcast = broadcast || gOut.shape() != gIns[i].shape();
+
+    gAddEqual<<<blocks, threads>>>(functor, gOut, gIns, scale, broadcast);
+  } else {
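+    // Case 3: out differs from the broadcast shape in some other way; fall back to
+    // the generic kernel, which handles arbitrary broadcasting and folding.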
+    int threads = std::min(MAX_THREADS, length);
+    int blocks
+        = std::min(MAX_BLOCKS, length / threads + (length % threads != 0));
+
+    gAddGeneric<<<blocks, threads>>>(functor, full, gOut, gIns, scale);
+  }
+#else
+  ABORT("Not implemented");
+#endif
+}
+
+}  // namespace gpu
+}  // namespace marian