 CMakeLists.txt                    |   2
 src/CMakeLists.txt                |   2
 src/gpu/array.h                   |   9
 src/gpu/defs.h                    |  22
 src/gpu/functions.h               |  98
 src/gpu/placeholders.h            |  96
 src/gpu/tmp.h                     |   1
 src/graph/node_operators_binary.h |   2
 src/graph/node_operators_unary.h  |  59
 src/kernels/tensor_operators.h    |  21
 src/kernels/thrust_functions.h    | 204
 src/optimizers/clippers.cu        |   6
 src/optimizers/optimizers.cu      |  17
 src/tests/tensor_test.cu          |  55
 src/training/graph_group_async.cu |   3
 15 files changed, 326 insertions(+), 271 deletions(-)
diff --git a/CMakeLists.txt b/CMakeLists.txt index bb471b98..fa502468 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -27,7 +27,7 @@ set(CMAKE_CXX_FLAGS ${CMAKE_CXX_FLAGS_RELEASE}) # Find packages find_package(CUDA "8.0" REQUIRED) if(CUDA_FOUND) - set(EXT_LIBS ${EXT_LIBS} ${CUDA_curand_LIBRARY} ${CUDA_cusparse_LIBRARY}) + set(EXT_LIBS ${EXT_LIBS} ${CUDA_curand_LIBRARY} ${CUDA_cusparse_LIBRARY} tcmalloc_minimal) endif(CUDA_FOUND) if (CMAKE_BUILD_TYPE STREQUAL "Debug") diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 7c2bc3e8..a177c0de 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -16,7 +16,7 @@ cuda_add_library(marian tensors/memory_piece.cu kernels/tensor_operators.cu kernels/dropout.cu - kernels/sparse.cu +# kernels/sparse.cu layers/param_initializers.cu layers/generic.cpp layers/guided_alignment.cpp diff --git a/src/gpu/array.h b/src/gpu/array.h index 65d80bfa..8981e8c7 100644 --- a/src/gpu/array.h +++ b/src/gpu/array.h @@ -1,18 +1,11 @@ #pragma once -#include <cuda.h> +#include "gpu/defs.h" namespace marian { namespace gpu { -#define __H__ __host__ -#define __D__ __device__ -#define __HI__ __host__ inline -#define __DI__ __device__ inline -#define __HD__ __host__ __device__ -#define __HDI__ __host__ __device__ inline - template <typename T, size_t N> struct Array { typedef T value_type; diff --git a/src/gpu/defs.h b/src/gpu/defs.h new file mode 100644 index 00000000..6fab8f8e --- /dev/null +++ b/src/gpu/defs.h @@ -0,0 +1,22 @@ +#pragma once + +#ifdef __CUDA_ARCH__ + +#include <cuda.h> +#define __H__ __host__ +#define __D__ __device__ +#define __HI__ __host__ inline +#define __DI__ __device__ inline +#define __HD__ __host__ __device__ +#define __HDI__ __host__ __device__ inline + +#else + +#define __H__ +#define __D__ +#define __HI__ inline +#define __DI__ inline +#define __HD__ +#define __HDI__ inline + +#endif
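The new gpu/defs.h centralizes the host/device qualifier macros that gpu/array.h previously defined for itself, and adds a plain-C++ fallback for translation units where __CUDA_ARCH__ is not defined. A minimal host-only sketch of how code written against these macros builds as ordinary C++ (assumptions: the marian include path is available, and squash is a hypothetical function, not part of this patch):

    #include <cstdio>
    #include "gpu/defs.h"   // __HDI__ reduces to plain "inline" in a non-CUDA build

    // Under the CUDA branch of defs.h the same signature would expand to
    // __host__ __device__ inline instead.
    __HDI__ float squash(float x) { return x > 0.f ? x : 0.f; }

    int main() {
      std::printf("%g\n", squash(-1.5f));  // prints 0
      return 0;
    }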
\ No newline at end of file diff --git a/src/gpu/functions.h b/src/gpu/functions.h new file mode 100644 index 00000000..64f7c602 --- /dev/null +++ b/src/gpu/functions.h @@ -0,0 +1,98 @@ +#pragma once + +#include "gpu/defs.h" +#include "gpu/placeholders.h" + +namespace marian { + namespace functional { + + template <typename Function, typename X> + struct UnaryFunctor { + X x; + + template <class Arg> + __HD__ UnaryFunctor(Arg a) : x(a) {} + + template <typename ...Args> + __HDI__ float operator()(Args&&... args) { + return Function::apply(x(args...)); + } + }; + + template <class Function, class X, class Y> + struct BinaryFunctor { + X x; + Y y; + + template <class Arg1, class Arg2> + __HD__ BinaryFunctor(Arg1 arg1, Arg2 arg2) : x(arg1), y(arg2) {} + + template <typename ...Args> + __HDI__ float operator()(Args&&... args) { + return Function::apply(x(args...), y(args...)); + } + }; + + #define UNARY(name, name2, func) \ + namespace elem { \ + struct name { \ + __HDI__ static float apply(float x) { return func; } \ + }; \ + }\ + template <class X> using name = UnaryFunctor<elem::name, X>;\ + template <typename X>\ + __HDI__ Op<name<X>> name2(Op<X> x) {\ + return Op<name<X>>(name<X>(x.f));\ + } + + #define BINARY(name, name2, func) \ + namespace elem { \ + struct name { \ + __HDI__ static float apply(float x, float y) { return func; } \ + }; \ + }\ + template <class X, class Y> using name = BinaryFunctor<elem::name, X, Y>;\ + template <typename X, typename Y>\ + __HDI__ Op<name<X, Y>> name2(Op<X> x, Op<Y> y) {\ + return Op<name<X, Y>>(name<X, Y>(x.f, y.f));\ + }\ + template <typename X>\ + __HDI__ Op<name<X, C>> name2(Op<X> x, float y) {\ + return name2(x, Op<C>(y));\ + }\ + template <typename Y>\ + __HDI__ Op<name<C, Y>> name2(float x, Op<Y> y) {\ + return name2(Op<C>(x), y);\ + } + + UNARY(Tanh, tanh, tanhf(x)); + UNARY(Sin, sin, sinf(x)); + UNARY(Cos, cos, cosf(x)); + UNARY(Tan, tan, tanf(x)); + UNARY(Log, log, logf(x)); + UNARY(Exp, exp, expf(x)); + UNARY(Abs, abs, fabs(x)); + UNARY(Sqrt, sqrt, sqrtf(x)); + UNARY(Neg, operator-, -x); + UNARY(Logit, logit, x > 0 ? (1.f / (1.f + expf(-x))) : (expf(x) / (1.f + expf(x)))); + + BINARY(Plus, operator+, x + y); + BINARY(Minus, operator-, x - y); + BINARY(Mult, operator*, x * y); + BINARY(Div, operator/, x / y); + BINARY(Pow, pow, pow(x, y)); + + template <typename T> + __HDI__ T sgn(T val) { + return (float(0) < val) - (val < float(0)); + } + + BINARY(Clip, clip, fabs(x) >= y ? sgn(x) * y : x); + + UNARY(sReLU, ReLU, x > 0.f ? x : 0.f); + UNARY(sReLUBack, ReLUback, x > 0.f ? 1.f : 0.f); + BINARY(sPReLU, PReLU, x > 0.f ? x : x * y); + BINARY(sPReLUBack, PReLUback, x > 0.f ? 1.f : y); + + } +}
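The UNARY/BINARY macros above generate both the element-wise functor (elem::Tanh, elem::Clip, ...) and an overload set over Op<...>, so an expression such as tanh(_2 + _3) composes into a nested functor type at compile time. Because every operator is marked __HDI__, the same expressions can also be exercised on the host, as src/tests/tensor_test.cu below does. A small sketch under that assumption (marian include path; lvalue floats stand in for tensor elements):

    #include <iostream>
    #include "gpu/functions.h"   // also pulls in gpu/placeholders.h

    int main() {
      using namespace marian::functional;

      // clip(x, c) comes from BINARY(Clip, clip, ...): |x| >= c ? sgn(x) * c : x
      auto clipper = clip(_1, 2.f);
      float a = 5.f, b = -0.5f;
      std::cout << clipper(a) << " " << clipper(b) << std::endl;  // 2 -0.5

      // Nested expression: every placeholder leaf sees all call arguments.
      auto swish = _1 * logit(_1);
      float x = 1.f;
      std::cout << swish(x) << std::endl;  // ~0.731
      return 0;
    }

Since the resulting Op objects are small value types holding only the expression tree (plus any captured constants), the whole functor can be passed by value into a kernel, which is how Element() in kernels/tensor_operators.h consumes them.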
\ No newline at end of file diff --git a/src/gpu/placeholders.h b/src/gpu/placeholders.h new file mode 100644 index 00000000..7e3e72e0 --- /dev/null +++ b/src/gpu/placeholders.h @@ -0,0 +1,96 @@ +#pragma once + +#include "gpu/defs.h" + +namespace marian { + namespace functional { + + template <int N> + struct Select { + template <typename T, typename ...Args> + __HDI__ static auto apply(T&& arg, Args&&... args) -> decltype(Select<N-1>::apply(args...)) { + return Select<N-1>::apply(args...); + } + }; + + template <> + struct Select<0> { + template <typename T, typename ...Args> + __HDI__ static T apply(T&& arg, Args&&... args) { + return arg; + } + }; + + template <int N> + struct X { + + template <typename ...Args> + __HDI__ float& operator()(Args&&... args) { + return Select<N-1>::apply(args...); + } + }; + + struct C { + float value; + + __HD__ C(const C& c) : value(c.value) {} + __HD__ C(float f) : value(f) {} + + template <typename ...Args> + __HDI__ float& operator()(Args&&... args) { return value; } + }; + + + template <class X, class Y> + struct Assign { + X x; + Y y; + + template <class Arg1, class Arg2> + __HD__ Assign(Arg1&& arg1, Arg2&& arg2) : x(arg1), y(arg2) {} + + template <typename ...Args> + __HDI__ float operator()(Args&&... args) { + return x(args...) = y(args...); + } + }; + + template <class F> + struct Op { + F f; + + __HD__ Op() {} + + template <class A> + __HD__ Op(A a) : f(a) {} + + template <class X> + __HD__ Op<Assign<F, X>> operator=(Op<X> x) { + return Op<Assign<F, X>>(Assign<F, X>(f, x.f)); + } + + __HD__ Op<Assign<F, C>> operator=(float x) { + return Op<Assign<F, C>>(Assign<F, C>(f, C(x))); + } + + template <typename ...Args> + __HDI__ float operator()(Args&&... args) { + return f(args...); + } + }; + + template <int N> + using ref = Op<X<N>>; + + static ref<1> _1; + static ref<2> _2; + static ref<3> _3; + static ref<4> _4; + static ref<5> _5; + static ref<6> _6; + static ref<7> _7; + static ref<8> _8; + static ref<9> _9; + + } +}
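placeholders.h supplies the positional arguments (_1 ... _9, each an Op<X<N>> that picks the N-th call argument through Select) and the Assign wrapper behind Op::operator=, which is what lets "_1 = expr" both evaluate and store through the first argument. A host-only sketch of those two pieces (assuming the marian include path):

    #include <iostream>
    #include "gpu/functions.h"

    int main() {
      using namespace marian::functional;

      // Select<N> makes the placeholders positional.
      float a = 10.f, b = 20.f, c = 30.f;
      std::cout << _3(a, b, c) << std::endl;   // 30

      // Op::operator= wraps both sides in Assign, so the functor writes
      // through the reference returned by the left-hand placeholder.
      auto axpy = (_1 = _1 + 2.f * _2);        // arg0 = arg0 + 2 * arg1
      float y = 1.f, x = 3.f;
      axpy(y, x);
      std::cout << y << std::endl;             // 7
      return 0;
    }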
\ No newline at end of file diff --git a/src/gpu/tmp.h b/src/gpu/tmp.h index c6685d05..6a03fbd1 100644 --- a/src/gpu/tmp.h +++ b/src/gpu/tmp.h @@ -1,5 +1,6 @@ #pragma once +#include "gpu/defs.h" #include "gpu/tensor.h" #include "gpu/array.h" diff --git a/src/graph/node_operators_binary.h b/src/graph/node_operators_binary.h index bbdd4cee..f5fe6cf5 100644 --- a/src/graph/node_operators_binary.h +++ b/src/graph/node_operators_binary.h @@ -5,7 +5,7 @@ #include "graph/backend_gpu.h" #include "graph/node.h" #include "kernels/tensor_operators.h" -#include "kernels/thrust_functions.h" +#include "gpu/functions.h" #ifdef CUDNN diff --git a/src/graph/node_operators_unary.h b/src/graph/node_operators_unary.h index fcf1c290..05294bee 100644 --- a/src/graph/node_operators_unary.h +++ b/src/graph/node_operators_unary.h @@ -4,8 +4,8 @@ #include "graph/node.h" #include "kernels/sparse.h" #include "kernels/tensor_operators.h" -#include "kernels/thrust_functions.h" #include "tensors/tensor.h" +#include "gpu/functions.h" #ifdef CUDNN @@ -51,6 +51,7 @@ public: : UnaryNodeOp(a, args...), scalar_{scalar} {} NodeOps forwardOps() { + using namespace functional; return {NodeOp(Element(_1 = _2 + scalar_, val_, child(0)->val()))}; } @@ -69,10 +70,12 @@ public: : UnaryNodeOp(a, args...), scalar_{scalar} {} NodeOps forwardOps() { + using namespace functional; return {NodeOp(Element(_1 = scalar_ * _2, val_, child(0)->val()))}; } NodeOps backwardOps() { + using namespace functional; return {NodeOp(Add(scalar_ * _1, child(0)->grad(), adj_))}; } @@ -84,10 +87,12 @@ struct LogitNodeOp : public UnaryNodeOp { LogitNodeOp(Args... args) : UnaryNodeOp(args...) {} NodeOps forwardOps() { - return {NodeOp(Element(_1 = Sigma(_2), val_, child(0)->val()))}; + using namespace functional; + return {NodeOp(Element(_1 = logit(_2), val_, child(0)->val()))}; } NodeOps backwardOps() { + using namespace functional; return {NodeOp(Add(_1 * _2 * (1.0f - _2), child(0)->grad(), adj_, val_))}; } @@ -145,13 +150,14 @@ struct TanhNodeOp : public NaryNodeOp { } NodeOps forwardOps() { + using namespace functional; switch(children_.size()) { - case 1: return {NodeOp(Element(_1 = Tanh(_2), val_, child(0)->val()))}; + case 1: return {NodeOp(Element(_1 = tanh(_2), val_, child(0)->val()))}; case 2: return {NodeOp(Element( - _1 = Tanh(_2 + _3), val_, child(0)->val(), child(1)->val()))}; + _1 = tanh(_2 + _3), val_, child(0)->val(), child(1)->val()))}; case 3: - return {NodeOp(Element(_1 = Tanh(_2 + _3 + _4), + return {NodeOp(Element(_1 = tanh(_2 + _3 + _4), val_, child(0)->val(), child(1)->val(), @@ -164,13 +170,14 @@ struct TanhNodeOp : public NaryNodeOp { child(1)->val(), child(2)->val()); for(int i = 3; i < children_.size(); ++i) - Element(_1 += _2, val_, child(i)->val()); - Element(_1 = Tanh(_1), val_);) + Element(_1 = _1 + _2, val_, child(i)->val()); + Element(_1 = tanh(_1), val_);) }; } } NodeOps backwardOps() { + using namespace functional; NodeOps ops; for(int i = 0; i < children_.size(); i++) { ops.push_back( @@ -205,6 +212,7 @@ struct ReLUNodeOp : public UnaryNodeOp { NodeOps forwardOps() { // f(x) = max(0, x) + using namespace functional; return {NodeOp(Element(_1 = ReLU(_2), val_, // _1 := f(x) to be calculated child(0)->val() // _2 := x @@ -212,6 +220,7 @@ struct ReLUNodeOp : public UnaryNodeOp { } NodeOps backwardOps() { + using namespace functional; // dJ/dx += dJ/df * binarystep(x) return {NodeOp(Add(_1 * ReLUback(_2), child(0)->grad(), // dJ/dx @@ -254,10 +263,12 @@ struct PReLUNodeOp : public UnaryNodeOp { : UnaryNodeOp(args...), 
alpha_(alpha) {} NodeOps forwardOps() { + using namespace functional; return {NodeOp(Element(_1 = PReLU(_2, alpha_), val_, child(0)->val()))}; } NodeOps backwardOps() { + using namespace functional; return {NodeOp(Add( _1 * PReLUback(_2, alpha_), child(0)->grad(), adj_, child(0)->val()))}; } @@ -283,12 +294,14 @@ struct SwishNodeOp : public UnaryNodeOp { SwishNodeOp(Args... args) : UnaryNodeOp(args...) {} NodeOps forwardOps() { - return {NodeOp(Element(_1 = _2 * Sigma(_2), val_, child(0)->val()))}; + using namespace functional; + return {NodeOp(Element(_1 = _2 * logit(_2), val_, child(0)->val()))}; } NodeOps backwardOps() { + using namespace functional; // dJ/dx += dJ/df * ( f(x) + sigma(x) * (1 - f(x)) ) - return {NodeOp(Add(_1 * (_3 + Sigma(_2) * (1.f - _3)), + return {NodeOp(Add(_1 * (_3 + logit(_2) * (1.f - _3)), child(0)->grad(), // dJ/dx adj_, // _1 := dJ/df child(0)->val(), // _2 := x @@ -424,6 +437,7 @@ struct MeanNodeOp : public UnaryNodeOp { : UnaryNodeOp(a, keywords::shape = newShape(a, args...), args...) {} NodeOps forwardOps() { + using namespace functional; int left = child(0)->shape().elements() / val_->shape().elements(); float scale = 1.f / left; @@ -431,6 +445,7 @@ struct MeanNodeOp : public UnaryNodeOp { } NodeOps backwardOps() { + using namespace functional; int left = child(0)->shape().elements() / val_->shape().elements(); float scale = 1.f / left; @@ -474,10 +489,12 @@ struct LogNodeOp : public UnaryNodeOp { LogNodeOp(Args... args) : UnaryNodeOp(args...) {} NodeOps forwardOps() { - return {NodeOp(Element(_1 = Log(_2), val_, child(0)->val()))}; + using namespace functional; + return {NodeOp(Element(_1 = log(_2), val_, child(0)->val()))}; } NodeOps backwardOps() { + using namespace functional; return { NodeOp(Add(_1 * (1.f / _2), child(0)->grad(), adj_, child(0)->val()))}; } @@ -490,11 +507,13 @@ struct ExpNodeOp : public UnaryNodeOp { ExpNodeOp(Args... args) : UnaryNodeOp(args...) {} NodeOps forwardOps() { - return {NodeOp(Element(_1 = Exp(_2), val_, child(0)->val()))}; + using namespace functional; + return {NodeOp(Element(_1 = exp(_2), val_, child(0)->val()))}; } NodeOps backwardOps() { - return {NodeOp(Add(_1 * Exp(_2), child(0)->grad(), adj_, child(0)->val()))}; + using namespace functional; + return {NodeOp(Add(_1 * exp(_2), child(0)->grad(), adj_, child(0)->val()))}; } const std::string type() { return "exp"; } @@ -508,10 +527,12 @@ struct SqrtNodeOp : public UnaryNodeOp { : UnaryNodeOp(a, args...), epsilon_(epsilon) {} NodeOps forwardOps() { - return {NodeOp(Element(_1 = Sqrt(_2 + epsilon_), val_, child(0)->val()))}; + using namespace functional; + return {NodeOp(Element(_1 = sqrt(_2 + epsilon_), val_, child(0)->val()))}; } NodeOps backwardOps() { + using namespace functional; return {NodeOp(Add(0.5f * (1.f / _1) * _2, child(0)->grad(), val_, adj_))}; } @@ -545,10 +566,12 @@ struct SquareNodeOp : public UnaryNodeOp { SquareNodeOp(Args... args) : UnaryNodeOp(args...) {} NodeOps forwardOps() { + using namespace functional; return {NodeOp(Element(_1 = _2 * _2, val_, child(0)->val()))}; } NodeOps backwardOps() { + using namespace functional; return { NodeOp(Add(2.f * _1 * _2, child(0)->grad(), child(0)->val(), adj_))}; } @@ -561,10 +584,14 @@ struct NegNodeOp : public UnaryNodeOp { NegNodeOp(Args... args) : UnaryNodeOp(args...) 
{} NodeOps forwardOps() { + using namespace functional; return {NodeOp(Element(_1 = -_2, val_, child(0)->val()))}; } - NodeOps backwardOps() { return {NodeOp(Add(-_1, child(0)->grad(), adj_))}; } + NodeOps backwardOps() { + using namespace functional; + return {NodeOp(Add(-_1, child(0)->grad(), adj_))}; + } const std::string type() { return "-"; } }; @@ -974,13 +1001,13 @@ struct ShiftNodeOp : public UnaryNodeOp { // void forward() { // sparse::LfaForward(val_, child(0)->val(), child(1)->val(), lf_); // // val = x + ln(p + eps) -// Element(_1 = (Log(_1 + eps_) + _2), val_, child(0)->val()); +// Element(_1 = (log(_1 + eps_) + _2), val_, child(0)->val()); // } // // void backward() { // Add(_1, child(0)->grad(), adj_); // // adj' = adj / (p + eps) = adj / exp(val - x) -// Element(_1 = _1 / Exp(_2 - _3), adj_, val_, child(0)->val()); +// Element(_1 = _1 / exp(_2 - _3), adj_, val_, child(0)->val()); // sparse::LfaBackward(child(1)->grad(), adj_, lf_); // } // diff --git a/src/kernels/tensor_operators.h b/src/kernels/tensor_operators.h index 6c7d7a03..54a14dd6 100644 --- a/src/kernels/tensor_operators.h +++ b/src/kernels/tensor_operators.h @@ -2,7 +2,6 @@ #include <cublas_v2.h> #include <thrust/device_vector.h> -#include <thrust/functional.h> #include <thrust/host_vector.h> #include <thrust/pair.h> @@ -25,10 +24,9 @@ const int MAX_BLOCKS = 65535; cublasHandle_t create_handle(size_t); -template <size_t K, class Functor> +template <size_t K, bool broadcast, class Functor> __global__ void gElement(Functor functor, - gpu::Array<gpu::Tensor<float>, K> tensors, - bool broadcast) { + gpu::Array<gpu::Tensor<float>, K> tensors) { int length = tensors[0].shape().elements(); gpu::Array<int, gpu::Shape::size()> dims; @@ -37,16 +35,16 @@ __global__ void gElement(Functor functor, for(int bid = 0; bid < length; bid += blockDim.x * gridDim.x) { int index = bid + blockDim.x * blockIdx.x + threadIdx.x; if(index < length) { + + indices.fill(index); + if(broadcast) { tensors[0].shape().dims(index, dims); - indices[0] = index; for(int i = 1; i < K; ++i) indices[i] = tensors[i].shape().bindex(dims); - tensors[0][index] = gpu::apply(functor, tensors, indices); - } - else { - tensors[0][index] = gpu::apply(functor, tensors, index); } + + tensors[0][index] = gpu::apply(functor, tensors, indices); } } } @@ -66,7 +64,10 @@ void Element(Functor functor, Tensor out, Tensors ...tensors) { for(int i = 1; i < K; ++i) broadcast = broadcast || gTensors[0].shape() != gTensors[i].shape(); - gElement<<<blocks, threads>>>(functor, gTensors, broadcast); + if(broadcast) + gElement<K, true><<<blocks, threads>>>(functor, gTensors); + else + gElement<K, false><<<blocks, threads>>>(functor, gTensors); } void TransposeND(Tensor out, Tensor in, const std::vector<int>& vAxis); diff --git a/src/kernels/thrust_functions.h b/src/kernels/thrust_functions.h deleted file mode 100644 index 1d91fc38..00000000 --- a/src/kernels/thrust_functions.h +++ /dev/null @@ -1,204 +0,0 @@ -#pragma once - -#include <cublas_v2.h> -#include <thrust/device_vector.h> -#include <thrust/functional.h> -#include <cmath> - -namespace thrust { -namespace detail { -namespace functional { - -template <typename T> -struct unary_exp : public thrust::unary_function<T, T> { - __host__ __device__ T operator()(const T &x) const { return expf(x); } -}; - -template <typename Eval> -__host__ __device__ actor<composite<unary_operator<unary_exp>, actor<Eval>>> -Exp(const actor<Eval> &_1) { - return compose(unary_operator<unary_exp>(), _1); -} - -template <typename T> -struct 
unary_log : public thrust::unary_function<T, T> { - __host__ __device__ T operator()(const T &x) const { return logf(x); } -}; - -template <typename Eval> -__host__ __device__ actor<composite<unary_operator<unary_log>, actor<Eval>>> -Log(const actor<Eval> &_1) { - return compose(unary_operator<unary_log>(), _1); -} - -template <typename T> -struct unary_sigma : public thrust::unary_function<T, T> { - __host__ __device__ T operator()(const T &x) const { - if(x >= 0) { - float z = expf(-x); - return 1.0 / (1.0 + z); - } else { - float z = expf(x); - return z / (1.0 + z); - } - } -}; - -template <typename Eval> -__host__ __device__ actor<composite<unary_operator<unary_sigma>, actor<Eval>>> -Sigma(const actor<Eval> &_1) { - return compose(unary_operator<unary_sigma>(), _1); -} - -template <typename T> -struct unary_tanh : public thrust::unary_function<T, T> { - __host__ __device__ T operator()(const T &x) const { return tanhf(x); } -}; - -template <typename Eval> -__host__ __device__ actor<composite<unary_operator<unary_tanh>, actor<Eval>>> -Tanh(const actor<Eval> &_1) { - return compose(unary_operator<unary_tanh>(), _1); -} - -template <typename T> -struct unary_sqrt : public thrust::unary_function<T, T> { - __host__ __device__ T operator()(const T &x) const { return sqrtf(x); } -}; - -template <typename Eval> -__host__ __device__ actor<composite<unary_operator<unary_sqrt>, actor<Eval>>> -Sqrt(const actor<Eval> &_1) { - return compose(unary_operator<unary_sqrt>(), _1); -} - -template <typename T1, typename T2> -__host__ __device__ - actor<composite<binary_operator<thrust::maximum>, actor<T1>, actor<T2>>> - Max(const actor<T1> &_1, const actor<T2> &_2) { - return compose( - binary_operator<thrust::maximum>(), make_actor(_1), make_actor(_2)); -} - -//******************************************************************* - -template <typename T> -struct unary_relu : public thrust::unary_function<T, T> { - __host__ __device__ T operator()(const T &x) const { - return x > 0.0f ? x : 0.0f; - } -}; - -template <typename Eval> -__host__ __device__ actor<composite<unary_operator<unary_relu>, actor<Eval>>> -ReLU(const actor<Eval> &_1) { - return compose(unary_operator<unary_relu>(), _1); -} - -template <typename T> -struct unary_reluback : public thrust::unary_function<T, T> { - __host__ __device__ T operator()(const T &x) const { - return x > 0.0f ? 1.0f : 0.0f; - } -}; - -template <typename Eval> -__host__ __device__ - actor<composite<unary_operator<unary_reluback>, actor<Eval>>> - ReLUback(const actor<Eval> &_1) { - return compose(unary_operator<unary_reluback>(), _1); -} - -//******************************************************************* - -template <typename T> -struct binary_prelu : public thrust::binary_function<T, T, T> { - __host__ __device__ T operator()(const T &x, const T &alpha) const { - return x > 0.0f ? x : alpha * x; - } -}; - -template <typename T1, typename T2> -__host__ __device__ actor<composite<binary_operator<binary_prelu>, - actor<T1>, - typename as_actor<T2>::type>> -PReLU(const actor<T1> &_1, const T2 &_2) { - return compose( - binary_operator<binary_prelu>(), make_actor(_1), make_actor(_2)); -} - -template <typename T> -struct binary_preluback : public thrust::binary_function<T, T, T> { - __host__ __device__ T operator()(const T &x, const T &alpha) const { - return x > 0.0f ? 
1.0f : alpha; - } -}; - -template <typename T1, typename T2> -__host__ __device__ actor<composite<binary_operator<binary_preluback>, - actor<T1>, - typename as_actor<T2>::type>> -PReLUback(const actor<T1> &_1, const T2 &_2) { - return compose( - binary_operator<binary_preluback>(), make_actor(_1), make_actor(_2)); -} - -//******************************************************************* - -template <typename T> -__host__ __device__ int sgn(T val) { - return (float(0) < val) - (val < float(0)); -} - -template <typename T> -struct binary_clip : public thrust::binary_function<T, T, T> { - __host__ __device__ T operator()(const T &x, const T &y) const { - return abs(x) >= y ? sgn(x) * y : x; - } -}; - -template <typename T1, typename T2> -__host__ __device__ actor<composite<binary_operator<binary_clip>, - actor<T1>, - typename as_actor<T2>::type>> -Clip(const actor<T1> &_1, const T2 &_2) { - return compose( - binary_operator<binary_clip>(), make_actor(_1), make_actor(_2)); -} - -template <typename T> -struct binary_prune : public thrust::binary_function<T, T, T> { - __host__ __device__ T operator()(const T &x, const T &eps) const { - return abs(x) >= eps ? x : 0; - } -}; - -template <typename T1, typename T2> -__host__ __device__ actor<composite<binary_operator<binary_prune>, - actor<T1>, - typename as_actor<T2>::type>> -Prune(const actor<T1> &_1, const T2 &_2) { - return compose( - binary_operator<binary_prune>(), make_actor(_1), make_actor(_2)); -} - -template <typename T> -struct binary_pow : public thrust::binary_function<T, T, T> { - __host__ __device__ T operator()(const T &x, const T &y) const { - float tx = x; - if(y == (int)y && (int)y % 2 == 0) - tx = abs(x); - return powf(tx, y); - } -}; - -template <typename T1, typename T2> -__host__ __device__ actor<composite<binary_operator<binary_pow>, - actor<T1>, - typename as_actor<T2>::type>> -Pow(const actor<T1> &_1, const T2 &_2) { - return compose(binary_operator<binary_pow>(), make_actor(_1), make_actor(_2)); -} -} -} -} diff --git a/src/optimizers/clippers.cu b/src/optimizers/clippers.cu index 3f3eeb6b..ee81c02f 100644 --- a/src/optimizers/clippers.cu +++ b/src/optimizers/clippers.cu @@ -1,14 +1,16 @@ #include "clippers.h" #include "kernels/tensor_operators.h" -#include "kernels/thrust_functions.h" +#include "gpu/functions.h" namespace marian { void Elementwise::clip(Tensor t) { - Element(_1 = Clip(_1, c_), t); + using namespace functional; + Element(_1 = functional::clip(_1, c_), t); } void Norm::clip(Tensor t) { + using namespace functional; float l2Norm = L2Norm(t); if(l2Norm >= c_) Element(_1 = (c_ / l2Norm) * _1, t); diff --git a/src/optimizers/optimizers.cu b/src/optimizers/optimizers.cu index dc5ed976..9c4a2ce7 100644 --- a/src/optimizers/optimizers.cu +++ b/src/optimizers/optimizers.cu @@ -1,11 +1,12 @@ #include "optimizers.h" #include "kernels/tensor_operators.h" -#include "kernels/thrust_functions.h" +#include "gpu/functions.h" namespace marian { void Sgd::updateImpl(Tensor params, Tensor grads) { - Element(_1 -= (multiplyFactor_ * eta_) * _2, params, grads); + using namespace functional; + Element(_1 = _1 - (multiplyFactor_ * eta_) * _2, params, grads); cudaStreamSynchronize(0); } @@ -21,9 +22,11 @@ void Adagrad::updateImpl(Tensor params, Tensor grads) { gt_->set(0); } - Element(_1 += (_2 * _2), gt_, grads); + using namespace functional; - Element(_1 -= ((multiplyFactor_ * eta_) / (Sqrt(_2) + eps_)) * _3, + Element(_1 = _1 + (_2 * _2), gt_, grads); + + Element(_1 = _1 - ((multiplyFactor_ * eta_) / (sqrt(_2) + eps_)) * _3, 
params, gt_, grads); @@ -55,11 +58,13 @@ void Adam::updateImpl(Tensor params, Tensor grads) { float denom1 = 1 - std::pow(beta1_, t_); float denom2 = 1 - std::pow(beta2_, t_); + using namespace functional; + Element(_1 = (beta1_ * _1) + ((1 - beta1_) * _2), mt_, grads); Element(_1 = (beta2_ * _1) + ((1 - beta2_) * (_2 * _2)), vt_, grads); - Element(_1 -= (multiplyFactor_ * eta_) * (_2 / denom1) - / (Sqrt(_3 / denom2) + eps_), + Element(_1 = _1 - (multiplyFactor_ * eta_) * (_2 / denom1) + / (sqrt(_3 / denom2) + eps_), params, mt_, vt_); diff --git a/src/tests/tensor_test.cu b/src/tests/tensor_test.cu index 42bce1f1..ba9be80c 100644 --- a/src/tests/tensor_test.cu +++ b/src/tests/tensor_test.cu @@ -1,31 +1,44 @@ -#include <boost/timer/timer.hpp> #include <iostream> -#include <map> -#include "marian.h" -#include "rnn/rnn.h" +#include "gpu/placeholders.h" +#include "gpu/functions.h" int main(int argc, char** argv) { - using namespace marian; - using namespace keywords; - auto graph = New<ExpressionGraph>(); - graph->setDevice(0); + using namespace marian::functional; - auto in1 = graph->constant({2, 2, 4}, init=inits::from_value(1)); - auto in2 = graph->constant({2, 2, 4}, init=inits::from_value(2)); - auto in3 = graph->constant({2, 2, 4}, init=inits::from_value(3)); - auto in4 = graph->constant({2, 2, 4}, init=inits::from_value(4)); + auto func = _1 = tanh(_2) * 3; - auto out = concatenate({in1, in2, in3, in4}, axis=1); + float z; + std::cerr << func(z, 2.f) << " " << sizeof(func) << std::endl; + std::cerr << z << " " << std::endl; + return 0; +} - debug(in1, "in1"); - debug(in2, "in2"); - debug(in3, "in3"); - debug(in4, "in4"); - debug(out, "out"); +/* - graph->forward(); +struct SwishNodeOp : public UnaryNodeOp { + template <typename... Args> + SwishNodeOp(Args... args) : UnaryNodeOp(args...) {} - return 0; -} + NodeOps forwardOps() { + + using namespace gpu::m; + ref<1> x; + auto swish = x * logit(x); + + return {NodeOp(Element(swish, val_, child(0)->val()))}; + } + + NodeOps backwardOps() { + + using namespace gpu::m; + ref<0> dJdf; + ref<1> x; + ref<2> f; + auto dJdx = dJdf * (f + logit(x) * (1 - f)); + + return {NodeOp(Add(dJdx, child(0)->grad(), adj_, child(0)->val(), val_))}; + } + + */
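Note that Op only defines operator= (via Assign) and no compound assignments, which is why the optimizer updates above are rewritten from "_1 -= ..." and "_1 += ..." to "_1 = _1 - ..." and "_1 = _1 + ...". A host-only sketch of the SGD-shaped expression those Element() calls now build (marian include path assumed; the multiplyFactor_ term is folded into a single hypothetical learning rate):

    #include <iostream>
    #include "gpu/functions.h"

    int main() {
      using namespace marian::functional;

      float eta = 0.1f;
      auto sgd = (_1 = _1 - eta * _2);   // same shape as the expression in Sgd::updateImpl

      float param = 1.0f, grad = 0.5f;
      sgd(param, grad);                  // param <- 1.0 - 0.1 * 0.5
      std::cout << param << std::endl;   // 0.95
      return 0;
    }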
\ No newline at end of file diff --git a/src/training/graph_group_async.cu b/src/training/graph_group_async.cu index d9511501..3f63eed2 100644 --- a/src/training/graph_group_async.cu +++ b/src/training/graph_group_async.cu @@ -1,7 +1,7 @@ #include "training/graph_group_async.h" #include "kernels/tensor_operators.h" -#include "kernels/thrust_functions.h" +#include "gpu/functions.h" namespace marian { @@ -72,6 +72,7 @@ void AsyncGraphGroup::updateMovingAverage(Tensor paramsAvg, Tensor params, size_t batches) { float decay = min(mvDecay_, (float)(batches + 1) / (float)(batches + 10)); + using namespace functional; Element(_1 = (decay * _1) + ((1.f - decay) * _2), paramsAvg, params); } |
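The graph_group_async.cu hunk only swaps headers and opens the functional namespace; the moving-average rule itself is unchanged: avg <- decay * avg + (1 - decay) * params, with decay ramped up as min(mvDecay_, (batches + 1) / (batches + 10)). A plain-C++ sketch of that rule with made-up numbers (scalar stand-ins for the tensors, hypothetical mvDecay_ value):

    #include <algorithm>
    #include <cstdio>

    int main() {
      float mvDecay = 0.9999f;        // hypothetical value of mvDecay_
      float avg = 0.f, param = 1.f;   // stand-ins for paramsAvg / params
      for (size_t batches = 0; batches < 5; ++batches) {
        // Early on the ratio term dominates, so the average tracks the raw
        // parameters quickly before the decay settles near mvDecay_.
        float decay = std::min(mvDecay, (float)(batches + 1) / (float)(batches + 10));
        avg = decay * avg + (1.f - decay) * param;  // Element(_1 = (decay * _1) + ((1.f - decay) * _2), ...)
        std::printf("batch %zu: decay=%.3f avg=%.3f\n", batches, decay, avg);
      }
      return 0;
    }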