
github.com/marian-nmt/marian.git
author     Marcin Junczys-Dowmunt <junczys@amu.edu.pl>  2017-11-02 22:37:43 +0300
committer  Marcin Junczys-Dowmunt <junczys@amu.edu.pl>  2017-11-02 22:37:43 +0300
commit     ca64c429e4aa4dcd49b68bab6a7d744fe06b44c2 (patch)
tree       d6469621854e3a44ec43527ac843d67d9eb1f586 /src
parent     cd6fca847dd9aa0b5236306beaaa8f778e2d8726 (diff)

new functional templates
Diffstat (limited to 'src')
 src/CMakeLists.txt                 |   2
 src/gpu/array.h                    |   9
 src/gpu/defs.h                     |  22
 src/gpu/functions.h                |  98
 src/gpu/placeholders.h             |  96
 src/gpu/tmp.h                      |   1
 src/graph/node_operators_binary.h  |   2
 src/graph/node_operators_unary.h   |  59
 src/kernels/tensor_operators.h     |  21
 src/kernels/thrust_functions.h     | 204
 src/optimizers/clippers.cu         |   6
 src/optimizers/optimizers.cu       |  17
 src/tests/tensor_test.cu           |  55
 src/training/graph_group_async.cu  |   3
 14 files changed, 325 insertions, 270 deletions
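
In short, this commit deletes the thrust-based expression helpers in src/kernels/thrust_functions.h and replaces them with a small header-only functional DSL under src/gpu/: defs.h (the __H__/__D__/__HDI__ qualifier macros, which become no-ops outside device compilation), placeholders.h (the _1.._9 placeholders, constants and assignment) and functions.h (the element-wise functors generated by the UNARY/BINARY macros). Call sites now open with "using namespace functional;" and move from the capitalized thrust actors to lower-case functors: Tanh/Log/Exp/Sqrt become tanh/log/exp/sqrt, Sigma becomes logit, Clip becomes clip, and compound assignments such as _1 -= ... are spelled out as _1 = _1 - ... . The build also comments out kernels/sparse.cu. A representative call-site change, taken from the node_operators_unary.h hunk below:

NodeOps forwardOps() {
-  return {NodeOp(Element(_1 = Sigma(_2), val_, child(0)->val()))};
+  using namespace functional;
+  return {NodeOp(Element(_1 = logit(_2), val_, child(0)->val()))};
}
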
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index 7c2bc3e8..a177c0de 100644
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -16,7 +16,7 @@ cuda_add_library(marian
tensors/memory_piece.cu
kernels/tensor_operators.cu
kernels/dropout.cu
- kernels/sparse.cu
+# kernels/sparse.cu
layers/param_initializers.cu
layers/generic.cpp
layers/guided_alignment.cpp
diff --git a/src/gpu/array.h b/src/gpu/array.h
index 65d80bfa..8981e8c7 100644
--- a/src/gpu/array.h
+++ b/src/gpu/array.h
@@ -1,18 +1,11 @@
#pragma once
-#include <cuda.h>
+#include "gpu/defs.h"
namespace marian {
namespace gpu {
-#define __H__ __host__
-#define __D__ __device__
-#define __HI__ __host__ inline
-#define __DI__ __device__ inline
-#define __HD__ __host__ __device__
-#define __HDI__ __host__ __device__ inline
-
template <typename T, size_t N>
struct Array {
typedef T value_type;
diff --git a/src/gpu/defs.h b/src/gpu/defs.h
new file mode 100644
index 00000000..6fab8f8e
--- /dev/null
+++ b/src/gpu/defs.h
@@ -0,0 +1,22 @@
+#pragma once
+
+#ifdef __CUDA_ARCH__
+
+#include <cuda.h>
+#define __H__ __host__
+#define __D__ __device__
+#define __HI__ __host__ inline
+#define __DI__ __device__ inline
+#define __HD__ __host__ __device__
+#define __HDI__ __host__ __device__ inline
+
+#else
+
+#define __H__
+#define __D__
+#define __HI__ inline
+#define __DI__ inline
+#define __HD__
+#define __HDI__ inline
+
+#endif
\ No newline at end of file
diff --git a/src/gpu/functions.h b/src/gpu/functions.h
new file mode 100644
index 00000000..64f7c602
--- /dev/null
+++ b/src/gpu/functions.h
@@ -0,0 +1,98 @@
+#pragma once
+
+#include "gpu/defs.h"
+#include "gpu/placeholders.h"
+
+namespace marian {
+ namespace functional {
+
+ template <typename Function, typename X>
+ struct UnaryFunctor {
+ X x;
+
+ template <class Arg>
+ __HD__ UnaryFunctor(Arg a) : x(a) {}
+
+ template <typename ...Args>
+ __HDI__ float operator()(Args&&... args) {
+ return Function::apply(x(args...));
+ }
+ };
+
+ template <class Function, class X, class Y>
+ struct BinaryFunctor {
+ X x;
+ Y y;
+
+ template <class Arg1, class Arg2>
+ __HD__ BinaryFunctor(Arg1 arg1, Arg2 arg2) : x(arg1), y(arg2) {}
+
+ template <typename ...Args>
+ __HDI__ float operator()(Args&&... args) {
+ return Function::apply(x(args...), y(args...));
+ }
+ };
+
+ #define UNARY(name, name2, func) \
+ namespace elem { \
+ struct name { \
+ __HDI__ static float apply(float x) { return func; } \
+ }; \
+ }\
+ template <class X> using name = UnaryFunctor<elem::name, X>;\
+ template <typename X>\
+ __HDI__ Op<name<X>> name2(Op<X> x) {\
+ return Op<name<X>>(name<X>(x.f));\
+ }
+
+ #define BINARY(name, name2, func) \
+ namespace elem { \
+ struct name { \
+ __HDI__ static float apply(float x, float y) { return func; } \
+ }; \
+ }\
+ template <class X, class Y> using name = BinaryFunctor<elem::name, X, Y>;\
+ template <typename X, typename Y>\
+ __HDI__ Op<name<X, Y>> name2(Op<X> x, Op<Y> y) {\
+ return Op<name<X, Y>>(name<X, Y>(x.f, y.f));\
+ }\
+ template <typename X>\
+ __HDI__ Op<name<X, C>> name2(Op<X> x, float y) {\
+ return name2(x, Op<C>(y));\
+ }\
+ template <typename Y>\
+ __HDI__ Op<name<C, Y>> name2(float x, Op<Y> y) {\
+ return name2(Op<C>(x), y);\
+ }
+
+ UNARY(Tanh, tanh, tanhf(x));
+ UNARY(Sin, sin, sinf(x));
+ UNARY(Cos, cos, cosf(x));
+ UNARY(Tan, tan, tanf(x));
+ UNARY(Log, log, logf(x));
+ UNARY(Exp, exp, expf(x));
+ UNARY(Abs, abs, fabs(x));
+ UNARY(Sqrt, sqrt, sqrtf(x));
+ UNARY(Neg, operator-, -x);
+ UNARY(Logit, logit, x > 0 ? (1.f / (1.f + expf(-x))) : (expf(x) / (1.f + expf(x))));
+
+ BINARY(Plus, operator+, x + y);
+ BINARY(Minus, operator-, x - y);
+ BINARY(Mult, operator*, x * y);
+ BINARY(Div, operator/, x / y);
+ BINARY(Pow, pow, pow(x, y));
+
+ template <typename T>
+ __HDI__ T sgn(T val) {
+ return (float(0) < val) - (val < float(0));
+ }
+
+ BINARY(Clip, clip, fabs(x) >= y ? sgn(x) * y : x);
+
+ UNARY(sReLU, ReLU, x > 0.f ? x : 0.f);
+ UNARY(sReLUBack, ReLUback, x > 0.f ? 1.f : 0.f);
+ BINARY(sPReLU, PReLU, x > 0.f ? x : x * y);
+ BINARY(sPReLUBack, PReLUback, x > 0.f ? 1.f : y);
+
+ }
+}
\ No newline at end of file
diff --git a/src/gpu/placeholders.h b/src/gpu/placeholders.h
new file mode 100644
index 00000000..7e3e72e0
--- /dev/null
+++ b/src/gpu/placeholders.h
@@ -0,0 +1,96 @@
+#pragma once
+
+#include "gpu/defs.h"
+
+namespace marian {
+ namespace functional {
+
+ template <int N>
+ struct Select {
+ template <typename T, typename ...Args>
+ __HDI__ static auto apply(T&& arg, Args&&... args) -> decltype(Select<N-1>::apply(args...)) {
+ return Select<N-1>::apply(args...);
+ }
+ };
+
+ template <>
+ struct Select<0> {
+ template <typename T, typename ...Args>
+ __HDI__ static T apply(T&& arg, Args&&... args) {
+ return arg;
+ }
+ };
+
+ template <int N>
+ struct X {
+
+ template <typename ...Args>
+ __HDI__ float& operator()(Args&&... args) {
+ return Select<N-1>::apply(args...);
+ }
+ };
+
+ struct C {
+ float value;
+
+ __HD__ C(const C& c) : value(c.value) {}
+ __HD__ C(float f) : value(f) {}
+
+ template <typename ...Args>
+ __HDI__ float& operator()(Args&&... args) { return value; }
+ };
+
+
+ template <class X, class Y>
+ struct Assign {
+ X x;
+ Y y;
+
+ template <class Arg1, class Arg2>
+ __HD__ Assign(Arg1&& arg1, Arg2&& arg2) : x(arg1), y(arg2) {}
+
+ template <typename ...Args>
+ __HDI__ float operator()(Args&&... args) {
+ return x(args...) = y(args...);
+ }
+ };
+
+ template <class F>
+ struct Op {
+ F f;
+
+ __HD__ Op() {}
+
+ template <class A>
+ __HD__ Op(A a) : f(a) {}
+
+ template <class X>
+ __HD__ Op<Assign<F, X>> operator=(Op<X> x) {
+ return Op<Assign<F, X>>(Assign<F, X>(f, x.f));
+ }
+
+ __HD__ Op<Assign<F, C>> operator=(float x) {
+ return Op<Assign<F, C>>(Assign<F, C>(f, C(x)));
+ }
+
+ template <typename ...Args>
+ __HDI__ float operator()(Args&&... args) {
+ return f(args...);
+ }
+ };
+
+ template <int N>
+ using ref = Op<X<N>>;
+
+ static ref<1> _1;
+ static ref<2> _2;
+ static ref<3> _3;
+ static ref<4> _4;
+ static ref<5> _5;
+ static ref<6> _6;
+ static ref<7> _7;
+ static ref<8> _8;
+ static ref<9> _9;
+
+ }
+}
\ No newline at end of file
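
Taken together, defs.h, placeholders.h and functions.h are just enough machinery to build assignable element-wise expressions out of the _1.._9 placeholders. Below is a minimal standalone, host-only sketch of that machinery with the CUDA qualifiers, the Op wrapper and the generated operator overloads omitted; the names (Select, X, C, Assign, UnaryFunctor, BinaryFunctor) follow the diff but the bodies are simplified, so treat it as an illustration of the mechanism rather than the actual marian headers.

#include <cmath>
#include <iostream>

// Select<N>::apply(a0, a1, ...) returns the N-th argument of the pack.
template <int N>
struct Select {
  template <typename T, typename ...Args>
  static auto apply(T&& arg, Args&& ...args)
      -> decltype(Select<N - 1>::apply(args...)) {
    return Select<N - 1>::apply(args...);
  }
};

template <>
struct Select<0> {
  template <typename T, typename ...Args>
  static T apply(T&& arg, Args&& ...args) {
    return arg;
  }
};

// X<N> is the placeholder _N: calling it with k arguments yields a
// reference to the N-th one.
template <int N>
struct X {
  template <typename ...Args>
  float& operator()(Args&& ...args) {
    return Select<N - 1>::apply(args...);
  }
};

// C captures a float constant.
struct C {
  float value;
  C(float f) : value(f) {}
  template <typename ...Args>
  float& operator()(Args&& ...args) { return value; }
};

// Assign evaluates the right-hand side and stores it through the left.
template <class L, class R>
struct Assign {
  L l;
  R r;
  Assign(L l, R r) : l(l), r(r) {}
  template <typename ...Args>
  float operator()(Args&& ...args) { return l(args...) = r(args...); }
};

// One unary and one binary element-wise function, roughly what the
// UNARY/BINARY macros in gpu/functions.h generate.
struct TanhF { static float apply(float x) { return std::tanh(x); } };
struct MultF { static float apply(float x, float y) { return x * y; } };

template <class F, class A>
struct UnaryFunctor {
  A a;
  UnaryFunctor(A a) : a(a) {}
  template <typename ...Args>
  float operator()(Args&& ...args) { return F::apply(a(args...)); }
};

template <class F, class A, class B>
struct BinaryFunctor {
  A a;
  B b;
  BinaryFunctor(A a, B b) : a(a), b(b) {}
  template <typename ...Args>
  float operator()(Args&& ...args) {
    return F::apply(a(args...), b(args...));
  }
};

int main() {
  // Build "_1 = tanh(_2) * 3" by hand; with the Op wrapper and the
  // operators from UNARY/BINARY this is spelled as in tensor_test.cu:
  //   auto func = (_1 = tanh(_2) * 3);
  X<1> _1;
  X<2> _2;
  UnaryFunctor<TanhF, X<2>> t(_2);
  BinaryFunctor<MultF, UnaryFunctor<TanhF, X<2>>, C> rhs(t, C(3.f));
  Assign<X<1>, BinaryFunctor<MultF, UnaryFunctor<TanhF, X<2>>, C>> func(_1, rhs);

  float z = 0.f;
  func(z, 2.f);                    // stores std::tanh(2.f) * 3 into z
  std::cout << z << std::endl;     // ~2.892
  return 0;
}
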
diff --git a/src/gpu/tmp.h b/src/gpu/tmp.h
index c6685d05..6a03fbd1 100644
--- a/src/gpu/tmp.h
+++ b/src/gpu/tmp.h
@@ -1,5 +1,6 @@
#pragma once
+#include "gpu/defs.h"
#include "gpu/tensor.h"
#include "gpu/array.h"
diff --git a/src/graph/node_operators_binary.h b/src/graph/node_operators_binary.h
index bbdd4cee..f5fe6cf5 100644
--- a/src/graph/node_operators_binary.h
+++ b/src/graph/node_operators_binary.h
@@ -5,7 +5,7 @@
#include "graph/backend_gpu.h"
#include "graph/node.h"
#include "kernels/tensor_operators.h"
-#include "kernels/thrust_functions.h"
+#include "gpu/functions.h"
#ifdef CUDNN
diff --git a/src/graph/node_operators_unary.h b/src/graph/node_operators_unary.h
index fcf1c290..05294bee 100644
--- a/src/graph/node_operators_unary.h
+++ b/src/graph/node_operators_unary.h
@@ -4,8 +4,8 @@
#include "graph/node.h"
#include "kernels/sparse.h"
#include "kernels/tensor_operators.h"
-#include "kernels/thrust_functions.h"
#include "tensors/tensor.h"
+#include "gpu/functions.h"
#ifdef CUDNN
@@ -51,6 +51,7 @@ public:
: UnaryNodeOp(a, args...), scalar_{scalar} {}
NodeOps forwardOps() {
+ using namespace functional;
return {NodeOp(Element(_1 = _2 + scalar_, val_, child(0)->val()))};
}
@@ -69,10 +70,12 @@ public:
: UnaryNodeOp(a, args...), scalar_{scalar} {}
NodeOps forwardOps() {
+ using namespace functional;
return {NodeOp(Element(_1 = scalar_ * _2, val_, child(0)->val()))};
}
NodeOps backwardOps() {
+ using namespace functional;
return {NodeOp(Add(scalar_ * _1, child(0)->grad(), adj_))};
}
@@ -84,10 +87,12 @@ struct LogitNodeOp : public UnaryNodeOp {
LogitNodeOp(Args... args) : UnaryNodeOp(args...) {}
NodeOps forwardOps() {
- return {NodeOp(Element(_1 = Sigma(_2), val_, child(0)->val()))};
+ using namespace functional;
+ return {NodeOp(Element(_1 = logit(_2), val_, child(0)->val()))};
}
NodeOps backwardOps() {
+ using namespace functional;
return {NodeOp(Add(_1 * _2 * (1.0f - _2), child(0)->grad(), adj_, val_))};
}
@@ -145,13 +150,14 @@ struct TanhNodeOp : public NaryNodeOp {
}
NodeOps forwardOps() {
+ using namespace functional;
switch(children_.size()) {
- case 1: return {NodeOp(Element(_1 = Tanh(_2), val_, child(0)->val()))};
+ case 1: return {NodeOp(Element(_1 = tanh(_2), val_, child(0)->val()))};
case 2:
return {NodeOp(Element(
- _1 = Tanh(_2 + _3), val_, child(0)->val(), child(1)->val()))};
+ _1 = tanh(_2 + _3), val_, child(0)->val(), child(1)->val()))};
case 3:
- return {NodeOp(Element(_1 = Tanh(_2 + _3 + _4),
+ return {NodeOp(Element(_1 = tanh(_2 + _3 + _4),
val_,
child(0)->val(),
child(1)->val(),
@@ -164,13 +170,14 @@ struct TanhNodeOp : public NaryNodeOp {
child(1)->val(),
child(2)->val());
for(int i = 3; i < children_.size(); ++i)
- Element(_1 += _2, val_, child(i)->val());
- Element(_1 = Tanh(_1), val_);)
+ Element(_1 = _1 + _2, val_, child(i)->val());
+ Element(_1 = tanh(_1), val_);)
};
}
}
NodeOps backwardOps() {
+ using namespace functional;
NodeOps ops;
for(int i = 0; i < children_.size(); i++) {
ops.push_back(
@@ -205,6 +212,7 @@ struct ReLUNodeOp : public UnaryNodeOp {
NodeOps forwardOps() {
// f(x) = max(0, x)
+ using namespace functional;
return {NodeOp(Element(_1 = ReLU(_2),
val_, // _1 := f(x) to be calculated
child(0)->val() // _2 := x
@@ -212,6 +220,7 @@ struct ReLUNodeOp : public UnaryNodeOp {
}
NodeOps backwardOps() {
+ using namespace functional;
// dJ/dx += dJ/df * binarystep(x)
return {NodeOp(Add(_1 * ReLUback(_2),
child(0)->grad(), // dJ/dx
@@ -254,10 +263,12 @@ struct PReLUNodeOp : public UnaryNodeOp {
: UnaryNodeOp(args...), alpha_(alpha) {}
NodeOps forwardOps() {
+ using namespace functional;
return {NodeOp(Element(_1 = PReLU(_2, alpha_), val_, child(0)->val()))};
}
NodeOps backwardOps() {
+ using namespace functional;
return {NodeOp(Add(
_1 * PReLUback(_2, alpha_), child(0)->grad(), adj_, child(0)->val()))};
}
@@ -283,12 +294,14 @@ struct SwishNodeOp : public UnaryNodeOp {
SwishNodeOp(Args... args) : UnaryNodeOp(args...) {}
NodeOps forwardOps() {
- return {NodeOp(Element(_1 = _2 * Sigma(_2), val_, child(0)->val()))};
+ using namespace functional;
+ return {NodeOp(Element(_1 = _2 * logit(_2), val_, child(0)->val()))};
}
NodeOps backwardOps() {
+ using namespace functional;
// dJ/dx += dJ/df * ( f(x) + sigma(x) * (1 - f(x)) )
- return {NodeOp(Add(_1 * (_3 + Sigma(_2) * (1.f - _3)),
+ return {NodeOp(Add(_1 * (_3 + logit(_2) * (1.f - _3)),
child(0)->grad(), // dJ/dx
adj_, // _1 := dJ/df
child(0)->val(), // _2 := x
@@ -424,6 +437,7 @@ struct MeanNodeOp : public UnaryNodeOp {
: UnaryNodeOp(a, keywords::shape = newShape(a, args...), args...) {}
NodeOps forwardOps() {
+ using namespace functional;
int left = child(0)->shape().elements() / val_->shape().elements();
float scale = 1.f / left;
@@ -431,6 +445,7 @@ struct MeanNodeOp : public UnaryNodeOp {
}
NodeOps backwardOps() {
+ using namespace functional;
int left = child(0)->shape().elements() / val_->shape().elements();
float scale = 1.f / left;
@@ -474,10 +489,12 @@ struct LogNodeOp : public UnaryNodeOp {
LogNodeOp(Args... args) : UnaryNodeOp(args...) {}
NodeOps forwardOps() {
- return {NodeOp(Element(_1 = Log(_2), val_, child(0)->val()))};
+ using namespace functional;
+ return {NodeOp(Element(_1 = log(_2), val_, child(0)->val()))};
}
NodeOps backwardOps() {
+ using namespace functional;
return {
NodeOp(Add(_1 * (1.f / _2), child(0)->grad(), adj_, child(0)->val()))};
}
@@ -490,11 +507,13 @@ struct ExpNodeOp : public UnaryNodeOp {
ExpNodeOp(Args... args) : UnaryNodeOp(args...) {}
NodeOps forwardOps() {
- return {NodeOp(Element(_1 = Exp(_2), val_, child(0)->val()))};
+ using namespace functional;
+ return {NodeOp(Element(_1 = exp(_2), val_, child(0)->val()))};
}
NodeOps backwardOps() {
- return {NodeOp(Add(_1 * Exp(_2), child(0)->grad(), adj_, child(0)->val()))};
+ using namespace functional;
+ return {NodeOp(Add(_1 * exp(_2), child(0)->grad(), adj_, child(0)->val()))};
}
const std::string type() { return "exp"; }
@@ -508,10 +527,12 @@ struct SqrtNodeOp : public UnaryNodeOp {
: UnaryNodeOp(a, args...), epsilon_(epsilon) {}
NodeOps forwardOps() {
- return {NodeOp(Element(_1 = Sqrt(_2 + epsilon_), val_, child(0)->val()))};
+ using namespace functional;
+ return {NodeOp(Element(_1 = sqrt(_2 + epsilon_), val_, child(0)->val()))};
}
NodeOps backwardOps() {
+ using namespace functional;
return {NodeOp(Add(0.5f * (1.f / _1) * _2, child(0)->grad(), val_, adj_))};
}
@@ -545,10 +566,12 @@ struct SquareNodeOp : public UnaryNodeOp {
SquareNodeOp(Args... args) : UnaryNodeOp(args...) {}
NodeOps forwardOps() {
+ using namespace functional;
return {NodeOp(Element(_1 = _2 * _2, val_, child(0)->val()))};
}
NodeOps backwardOps() {
+ using namespace functional;
return {
NodeOp(Add(2.f * _1 * _2, child(0)->grad(), child(0)->val(), adj_))};
}
@@ -561,10 +584,14 @@ struct NegNodeOp : public UnaryNodeOp {
NegNodeOp(Args... args) : UnaryNodeOp(args...) {}
NodeOps forwardOps() {
+ using namespace functional;
return {NodeOp(Element(_1 = -_2, val_, child(0)->val()))};
}
- NodeOps backwardOps() { return {NodeOp(Add(-_1, child(0)->grad(), adj_))}; }
+ NodeOps backwardOps() {
+ using namespace functional;
+ return {NodeOp(Add(-_1, child(0)->grad(), adj_))};
+ }
const std::string type() { return "-"; }
};
@@ -974,13 +1001,13 @@ struct ShiftNodeOp : public UnaryNodeOp {
// void forward() {
// sparse::LfaForward(val_, child(0)->val(), child(1)->val(), lf_);
// // val = x + ln(p + eps)
-// Element(_1 = (Log(_1 + eps_) + _2), val_, child(0)->val());
+// Element(_1 = (log(_1 + eps_) + _2), val_, child(0)->val());
// }
//
// void backward() {
// Add(_1, child(0)->grad(), adj_);
// // adj' = adj / (p + eps) = adj / exp(val - x)
-// Element(_1 = _1 / Exp(_2 - _3), adj_, val_, child(0)->val());
+// Element(_1 = _1 / exp(_2 - _3), adj_, val_, child(0)->val());
// sparse::LfaBackward(child(1)->grad(), adj_, lf_);
// }
//
diff --git a/src/kernels/tensor_operators.h b/src/kernels/tensor_operators.h
index 6c7d7a03..54a14dd6 100644
--- a/src/kernels/tensor_operators.h
+++ b/src/kernels/tensor_operators.h
@@ -2,7 +2,6 @@
#include <cublas_v2.h>
#include <thrust/device_vector.h>
-#include <thrust/functional.h>
#include <thrust/host_vector.h>
#include <thrust/pair.h>
@@ -25,10 +24,9 @@ const int MAX_BLOCKS = 65535;
cublasHandle_t create_handle(size_t);
-template <size_t K, class Functor>
+template <size_t K, bool broadcast, class Functor>
__global__ void gElement(Functor functor,
- gpu::Array<gpu::Tensor<float>, K> tensors,
- bool broadcast) {
+ gpu::Array<gpu::Tensor<float>, K> tensors) {
int length = tensors[0].shape().elements();
gpu::Array<int, gpu::Shape::size()> dims;
@@ -37,16 +35,16 @@ __global__ void gElement(Functor functor,
for(int bid = 0; bid < length; bid += blockDim.x * gridDim.x) {
int index = bid + blockDim.x * blockIdx.x + threadIdx.x;
if(index < length) {
+
+ indices.fill(index);
+
if(broadcast) {
tensors[0].shape().dims(index, dims);
- indices[0] = index;
for(int i = 1; i < K; ++i)
indices[i] = tensors[i].shape().bindex(dims);
- tensors[0][index] = gpu::apply(functor, tensors, indices);
- }
- else {
- tensors[0][index] = gpu::apply(functor, tensors, index);
}
+
+ tensors[0][index] = gpu::apply(functor, tensors, indices);
}
}
}
@@ -66,7 +64,10 @@ void Element(Functor functor, Tensor out, Tensors ...tensors) {
for(int i = 1; i < K; ++i)
broadcast = broadcast || gTensors[0].shape() != gTensors[i].shape();
- gElement<<<blocks, threads>>>(functor, gTensors, broadcast);
+ if(broadcast)
+ gElement<K, true><<<blocks, threads>>>(functor, gTensors);
+ else
+ gElement<K, false><<<blocks, threads>>>(functor, gTensors);
}
void TransposeND(Tensor out, Tensor in, const std::vector<int>& vAxis);
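
The gElement change above lifts the broadcast decision out of the per-element loop: the runtime bool argument becomes a compile-time template parameter, and the Element() wrapper picks one of the two kernel instantiations. A plain C++ stand-in for that dispatch pattern (the element/Element names here are toy stand-ins for gElement/Element, not the actual CUDA kernel):

#include <cstdio>

// Compile-time flag: each instantiation keeps only one branch.
template <bool broadcast>
void element(float* out, const float* in, int n) {
  for(int i = 0; i < n; ++i) {
    if(broadcast)
      out[i] = in[0];   // toy stand-in for remapping indices via bindex()
    else
      out[i] = in[i];
  }
}

// Runtime flag selects the instantiation, as the new Element() wrapper
// does with gElement<K, true> / gElement<K, false>.
void Element(float* out, const float* in, int n, bool broadcast) {
  if(broadcast)
    element<true>(out, in, n);
  else
    element<false>(out, in, n);
}

int main() {
  float in[4] = {1.f, 2.f, 3.f, 4.f};
  float out[4];
  Element(out, in, 4, false);
  std::printf("%g %g %g %g\n", out[0], out[1], out[2], out[3]);  // 1 2 3 4
  return 0;
}
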
diff --git a/src/kernels/thrust_functions.h b/src/kernels/thrust_functions.h
deleted file mode 100644
index 1d91fc38..00000000
--- a/src/kernels/thrust_functions.h
+++ /dev/null
@@ -1,204 +0,0 @@
-#pragma once
-
-#include <cublas_v2.h>
-#include <thrust/device_vector.h>
-#include <thrust/functional.h>
-#include <cmath>
-
-namespace thrust {
-namespace detail {
-namespace functional {
-
-template <typename T>
-struct unary_exp : public thrust::unary_function<T, T> {
- __host__ __device__ T operator()(const T &x) const { return expf(x); }
-};
-
-template <typename Eval>
-__host__ __device__ actor<composite<unary_operator<unary_exp>, actor<Eval>>>
-Exp(const actor<Eval> &_1) {
- return compose(unary_operator<unary_exp>(), _1);
-}
-
-template <typename T>
-struct unary_log : public thrust::unary_function<T, T> {
- __host__ __device__ T operator()(const T &x) const { return logf(x); }
-};
-
-template <typename Eval>
-__host__ __device__ actor<composite<unary_operator<unary_log>, actor<Eval>>>
-Log(const actor<Eval> &_1) {
- return compose(unary_operator<unary_log>(), _1);
-}
-
-template <typename T>
-struct unary_sigma : public thrust::unary_function<T, T> {
- __host__ __device__ T operator()(const T &x) const {
- if(x >= 0) {
- float z = expf(-x);
- return 1.0 / (1.0 + z);
- } else {
- float z = expf(x);
- return z / (1.0 + z);
- }
- }
-};
-
-template <typename Eval>
-__host__ __device__ actor<composite<unary_operator<unary_sigma>, actor<Eval>>>
-Sigma(const actor<Eval> &_1) {
- return compose(unary_operator<unary_sigma>(), _1);
-}
-
-template <typename T>
-struct unary_tanh : public thrust::unary_function<T, T> {
- __host__ __device__ T operator()(const T &x) const { return tanhf(x); }
-};
-
-template <typename Eval>
-__host__ __device__ actor<composite<unary_operator<unary_tanh>, actor<Eval>>>
-Tanh(const actor<Eval> &_1) {
- return compose(unary_operator<unary_tanh>(), _1);
-}
-
-template <typename T>
-struct unary_sqrt : public thrust::unary_function<T, T> {
- __host__ __device__ T operator()(const T &x) const { return sqrtf(x); }
-};
-
-template <typename Eval>
-__host__ __device__ actor<composite<unary_operator<unary_sqrt>, actor<Eval>>>
-Sqrt(const actor<Eval> &_1) {
- return compose(unary_operator<unary_sqrt>(), _1);
-}
-
-template <typename T1, typename T2>
-__host__ __device__
- actor<composite<binary_operator<thrust::maximum>, actor<T1>, actor<T2>>>
- Max(const actor<T1> &_1, const actor<T2> &_2) {
- return compose(
- binary_operator<thrust::maximum>(), make_actor(_1), make_actor(_2));
-}
-
-//*******************************************************************
-
-template <typename T>
-struct unary_relu : public thrust::unary_function<T, T> {
- __host__ __device__ T operator()(const T &x) const {
- return x > 0.0f ? x : 0.0f;
- }
-};
-
-template <typename Eval>
-__host__ __device__ actor<composite<unary_operator<unary_relu>, actor<Eval>>>
-ReLU(const actor<Eval> &_1) {
- return compose(unary_operator<unary_relu>(), _1);
-}
-
-template <typename T>
-struct unary_reluback : public thrust::unary_function<T, T> {
- __host__ __device__ T operator()(const T &x) const {
- return x > 0.0f ? 1.0f : 0.0f;
- }
-};
-
-template <typename Eval>
-__host__ __device__
- actor<composite<unary_operator<unary_reluback>, actor<Eval>>>
- ReLUback(const actor<Eval> &_1) {
- return compose(unary_operator<unary_reluback>(), _1);
-}
-
-//*******************************************************************
-
-template <typename T>
-struct binary_prelu : public thrust::binary_function<T, T, T> {
- __host__ __device__ T operator()(const T &x, const T &alpha) const {
- return x > 0.0f ? x : alpha * x;
- }
-};
-
-template <typename T1, typename T2>
-__host__ __device__ actor<composite<binary_operator<binary_prelu>,
- actor<T1>,
- typename as_actor<T2>::type>>
-PReLU(const actor<T1> &_1, const T2 &_2) {
- return compose(
- binary_operator<binary_prelu>(), make_actor(_1), make_actor(_2));
-}
-
-template <typename T>
-struct binary_preluback : public thrust::binary_function<T, T, T> {
- __host__ __device__ T operator()(const T &x, const T &alpha) const {
- return x > 0.0f ? 1.0f : alpha;
- }
-};
-
-template <typename T1, typename T2>
-__host__ __device__ actor<composite<binary_operator<binary_preluback>,
- actor<T1>,
- typename as_actor<T2>::type>>
-PReLUback(const actor<T1> &_1, const T2 &_2) {
- return compose(
- binary_operator<binary_preluback>(), make_actor(_1), make_actor(_2));
-}
-
-//*******************************************************************
-
-template <typename T>
-__host__ __device__ int sgn(T val) {
- return (float(0) < val) - (val < float(0));
-}
-
-template <typename T>
-struct binary_clip : public thrust::binary_function<T, T, T> {
- __host__ __device__ T operator()(const T &x, const T &y) const {
- return abs(x) >= y ? sgn(x) * y : x;
- }
-};
-
-template <typename T1, typename T2>
-__host__ __device__ actor<composite<binary_operator<binary_clip>,
- actor<T1>,
- typename as_actor<T2>::type>>
-Clip(const actor<T1> &_1, const T2 &_2) {
- return compose(
- binary_operator<binary_clip>(), make_actor(_1), make_actor(_2));
-}
-
-template <typename T>
-struct binary_prune : public thrust::binary_function<T, T, T> {
- __host__ __device__ T operator()(const T &x, const T &eps) const {
- return abs(x) >= eps ? x : 0;
- }
-};
-
-template <typename T1, typename T2>
-__host__ __device__ actor<composite<binary_operator<binary_prune>,
- actor<T1>,
- typename as_actor<T2>::type>>
-Prune(const actor<T1> &_1, const T2 &_2) {
- return compose(
- binary_operator<binary_prune>(), make_actor(_1), make_actor(_2));
-}
-
-template <typename T>
-struct binary_pow : public thrust::binary_function<T, T, T> {
- __host__ __device__ T operator()(const T &x, const T &y) const {
- float tx = x;
- if(y == (int)y && (int)y % 2 == 0)
- tx = abs(x);
- return powf(tx, y);
- }
-};
-
-template <typename T1, typename T2>
-__host__ __device__ actor<composite<binary_operator<binary_pow>,
- actor<T1>,
- typename as_actor<T2>::type>>
-Pow(const actor<T1> &_1, const T2 &_2) {
- return compose(binary_operator<binary_pow>(), make_actor(_1), make_actor(_2));
-}
-}
-}
-}
diff --git a/src/optimizers/clippers.cu b/src/optimizers/clippers.cu
index 3f3eeb6b..ee81c02f 100644
--- a/src/optimizers/clippers.cu
+++ b/src/optimizers/clippers.cu
@@ -1,14 +1,16 @@
#include "clippers.h"
#include "kernels/tensor_operators.h"
-#include "kernels/thrust_functions.h"
+#include "gpu/functions.h"
namespace marian {
void Elementwise::clip(Tensor t) {
- Element(_1 = Clip(_1, c_), t);
+ using namespace functional;
+ Element(_1 = functional::clip(_1, c_), t);
}
void Norm::clip(Tensor t) {
+ using namespace functional;
float l2Norm = L2Norm(t);
if(l2Norm >= c_)
Element(_1 = (c_ / l2Norm) * _1, t);
diff --git a/src/optimizers/optimizers.cu b/src/optimizers/optimizers.cu
index dc5ed976..9c4a2ce7 100644
--- a/src/optimizers/optimizers.cu
+++ b/src/optimizers/optimizers.cu
@@ -1,11 +1,12 @@
#include "optimizers.h"
#include "kernels/tensor_operators.h"
-#include "kernels/thrust_functions.h"
+#include "gpu/functions.h"
namespace marian {
void Sgd::updateImpl(Tensor params, Tensor grads) {
- Element(_1 -= (multiplyFactor_ * eta_) * _2, params, grads);
+ using namespace functional;
+ Element(_1 = _1 - (multiplyFactor_ * eta_) * _2, params, grads);
cudaStreamSynchronize(0);
}
@@ -21,9 +22,11 @@ void Adagrad::updateImpl(Tensor params, Tensor grads) {
gt_->set(0);
}
- Element(_1 += (_2 * _2), gt_, grads);
+ using namespace functional;
- Element(_1 -= ((multiplyFactor_ * eta_) / (Sqrt(_2) + eps_)) * _3,
+ Element(_1 = _1 + (_2 * _2), gt_, grads);
+
+ Element(_1 = _1 - ((multiplyFactor_ * eta_) / (sqrt(_2) + eps_)) * _3,
params,
gt_,
grads);
@@ -55,11 +58,13 @@ void Adam::updateImpl(Tensor params, Tensor grads) {
float denom1 = 1 - std::pow(beta1_, t_);
float denom2 = 1 - std::pow(beta2_, t_);
+ using namespace functional;
+
Element(_1 = (beta1_ * _1) + ((1 - beta1_) * _2), mt_, grads);
Element(_1 = (beta2_ * _1) + ((1 - beta2_) * (_2 * _2)), vt_, grads);
- Element(_1 -= (multiplyFactor_ * eta_) * (_2 / denom1)
- / (Sqrt(_3 / denom2) + eps_),
+ Element(_1 = _1 - (multiplyFactor_ * eta_) * (_2 / denom1)
+ / (sqrt(_3 / denom2) + eps_),
params,
mt_,
vt_);
diff --git a/src/tests/tensor_test.cu b/src/tests/tensor_test.cu
index 42bce1f1..ba9be80c 100644
--- a/src/tests/tensor_test.cu
+++ b/src/tests/tensor_test.cu
@@ -1,31 +1,44 @@
-#include <boost/timer/timer.hpp>
#include <iostream>
-#include <map>
-#include "marian.h"
-#include "rnn/rnn.h"
+#include "gpu/placeholders.h"
+#include "gpu/functions.h"
int main(int argc, char** argv) {
- using namespace marian;
- using namespace keywords;
- auto graph = New<ExpressionGraph>();
- graph->setDevice(0);
+ using namespace marian::functional;
- auto in1 = graph->constant({2, 2, 4}, init=inits::from_value(1));
- auto in2 = graph->constant({2, 2, 4}, init=inits::from_value(2));
- auto in3 = graph->constant({2, 2, 4}, init=inits::from_value(3));
- auto in4 = graph->constant({2, 2, 4}, init=inits::from_value(4));
+ auto func = _1 = tanh(_2) * 3;
- auto out = concatenate({in1, in2, in3, in4}, axis=1);
+ float z;
+ std::cerr << func(z, 2.f) << " " << sizeof(func) << std::endl;
+ std::cerr << z << " " << std::endl;
+ return 0;
+}
- debug(in1, "in1");
- debug(in2, "in2");
- debug(in3, "in3");
- debug(in4, "in4");
- debug(out, "out");
+/*
- graph->forward();
+struct SwishNodeOp : public UnaryNodeOp {
+ template <typename... Args>
+ SwishNodeOp(Args... args) : UnaryNodeOp(args...) {}
- return 0;
-}
+ NodeOps forwardOps() {
+
+ using namespace gpu::m;
+ ref<1> x;
+ auto swish = x * logit(x);
+
+ return {NodeOp(Element(swish, val_, child(0)->val()))};
+ }
+
+ NodeOps backwardOps() {
+
+ using namespace gpu::m;
+ ref<0> dJdf;
+ ref<1> x;
+ ref<2> f;
+ auto dJdx = dJdf * (f + logit(x) * (1 - f));
+
+ return {NodeOp(Add(dJdx, child(0)->grad(), adj_, child(0)->val(), val_))};
+ }
+
+ */
\ No newline at end of file
diff --git a/src/training/graph_group_async.cu b/src/training/graph_group_async.cu
index d9511501..3f63eed2 100644
--- a/src/training/graph_group_async.cu
+++ b/src/training/graph_group_async.cu
@@ -1,7 +1,7 @@
#include "training/graph_group_async.h"
#include "kernels/tensor_operators.h"
-#include "kernels/thrust_functions.h"
+#include "gpu/functions.h"
namespace marian {
@@ -72,6 +72,7 @@ void AsyncGraphGroup::updateMovingAverage(Tensor paramsAvg,
Tensor params,
size_t batches) {
float decay = min(mvDecay_, (float)(batches + 1) / (float)(batches + 10));
+ using namespace functional;
Element(_1 = (decay * _1) + ((1.f - decay) * _2), paramsAvg, params);
}