author    | Marcin Junczys-Dowmunt <junczys@amu.edu.pl> | 2017-11-02 22:58:56 +0300
committer | Marcin Junczys-Dowmunt <junczys@amu.edu.pl> | 2017-11-02 22:58:56 +0300
commit    | 9d06d786ba568e74f54be98f5a23d9b53d87ed7b (patch)
tree      | 8d82c2e2c33d9fa113d5bd9fd52bb8cd099db3bd
parent    | ca64c429e4aa4dcd49b68bab6a7d744fe06b44c2 (diff)
farewell thrust
-rw-r--r-- | src/3rd_party/reduce_all.h | 4
-rw-r--r-- | src/graph/node_operators_binary.h | 24
-rw-r--r-- | src/graph/node_operators_unary.h | 14
-rw-r--r-- | src/kernels/tensor_operators.cu | 16
-rw-r--r-- | src/kernels/tensor_operators.h | 5
-rw-r--r-- | src/layers/param_initializers.cu | 2
-rw-r--r-- | src/training/graph_group_async.cu | 3
-rw-r--r-- | src/training/graph_group_singleton.cu | 1
-rw-r--r-- | src/training/graph_group_sync.cu | 5
9 files changed, 58 insertions, 16 deletions
diff --git a/src/3rd_party/reduce_all.h b/src/3rd_party/reduce_all.h
index d6aad524..b506a436 100644
--- a/src/3rd_party/reduce_all.h
+++ b/src/3rd_party/reduce_all.h
@@ -90,7 +90,7 @@ reduceBlock(volatile float *sdata, float mySum, const unsigned int tid)
 template <unsigned int blockSize, bool nIsPow2, class Functor>
 __device__ void
-reduceBlocks(Functor f, const float *g_idata, float *g_odata, unsigned int n)
+reduceBlocks(Functor f, float *g_idata, float *g_odata, unsigned int n)
 {
     extern __shared__ float sdata[];
@@ -147,7 +147,7 @@ cudaError_t setRetirementCount(int retCnt)
 // the "reduction" sample in the CUDA SDK.
 template <unsigned int blockSize, bool nIsPow2, class Functor>
-__global__ void reduceSinglePass(Functor f, const float *g_idata, float *g_odata, unsigned int n)
+__global__ void reduceSinglePass(Functor f, float *g_idata, float *g_odata, unsigned int n)
 {
     //
diff --git a/src/graph/node_operators_binary.h b/src/graph/node_operators_binary.h
index f5fe6cf5..86fafdfb 100644
--- a/src/graph/node_operators_binary.h
+++ b/src/graph/node_operators_binary.h
@@ -193,6 +193,8 @@ struct AffineNodeOp : public NaryNodeOp {
   }

   NodeOps forwardOps() {
+    using namespace functional;
+
     return {
       NodeOp(Prod(std::static_pointer_cast<BackendGPU>(getBackend())
                       ->getCublasHandle(),
@@ -207,6 +209,8 @@ struct AffineNodeOp : public NaryNodeOp {
   }

   NodeOps backwardOps() {
+    using namespace functional;
+
     // D is the adjoint, the matrix of derivatives
     // df/dA += D*B.T
     // df/dB += A.T*D
@@ -405,10 +409,14 @@ struct ScalarProductNodeOp : public NaryNodeOp {
   }

   NodeOps forwardOps() {
+    using namespace functional;
+
     return {NodeOp(Reduce(_1 * _2, val_, child(0)->val(), child(1)->val()))};
   }

   NodeOps backwardOps() {
+    using namespace functional;
+
     return {NodeOp(Add(_1 * _2, child(0)->grad(), child(1)->val(), adj_)),
             NodeOp(Add(_1 * _2, child(1)->grad(), child(0)->val(), adj_))};
   }
@@ -435,11 +443,15 @@ struct PlusNodeOp : public ElementBinaryNodeOp {
   PlusNodeOp(Args... args) : ElementBinaryNodeOp(args...) {}

   NodeOps forwardOps() {
+    using namespace functional;
+
     return {
         NodeOp(Element(_1 = _2 + _3, val_, child(0)->val(), child(1)->val()))};
   }

   NodeOps backwardOps() {
+    using namespace functional;
+
     return {NodeOp(Add(_1, child(0)->grad(), adj_)),
             NodeOp(Add(_1, child(1)->grad(), adj_))};
   }
@@ -452,11 +464,15 @@ struct MinusNodeOp : public ElementBinaryNodeOp {
   MinusNodeOp(Args... args) : ElementBinaryNodeOp(args...) {}

   NodeOps forwardOps() {
+    using namespace functional;
+
     return {
         NodeOp(Element(_1 = _2 - _3, val_, child(0)->val(), child(1)->val()))};
   }

   NodeOps backwardOps() {
+    using namespace functional;
+
     return {NodeOp(Add(_1, child(0)->grad(), adj_)),
             NodeOp(Add(-_1, child(1)->grad(), adj_))};
   }
@@ -469,11 +485,15 @@ struct MultNodeOp : public ElementBinaryNodeOp {
   MultNodeOp(Args... args) : ElementBinaryNodeOp(args...) {}

   NodeOps forwardOps() {
+    using namespace functional;
+
     return {
         NodeOp(Element(_1 = _2 * _3, val_, child(0)->val(), child(1)->val()))};
   }

   NodeOps backwardOps() {
+    using namespace functional;
+
     return {NodeOp(Add(_1 * _2, child(0)->grad(), adj_, child(1)->val())),
             NodeOp(Add(_1 * _2, child(1)->grad(), adj_, child(0)->val()))};
   }
@@ -486,11 +506,15 @@ struct DivNodeOp : public ElementBinaryNodeOp {
   DivNodeOp(Args... args) : ElementBinaryNodeOp(args...) {}

   NodeOps forwardOps() {
+    using namespace functional;
+
     return {
         NodeOp(Element(_1 = _2 / _3, val_, child(0)->val(), child(1)->val()))};
   }

   NodeOps backwardOps() {
+    using namespace functional;
+
     return {
         NodeOp(Add(_1 * 1.0f / _2, child(0)->grad(), adj_, child(1)->val())),
         NodeOp(Add(-_1 * _2 / (_3 * _3),
diff --git a/src/graph/node_operators_unary.h b/src/graph/node_operators_unary.h
index 05294bee..9d5b8287 100644
--- a/src/graph/node_operators_unary.h
+++ b/src/graph/node_operators_unary.h
@@ -55,7 +55,10 @@ public:
     return {NodeOp(Element(_1 = _2 + scalar_, val_, child(0)->val()))};
   }

-  NodeOps backwardOps() { return {NodeOp(Add(_1, child(0)->grad(), adj_))}; }
+  NodeOps backwardOps() {
+    using namespace functional;
+    return {NodeOp(Add(_1, child(0)->grad(), adj_))};
+  }

   const std::string type() { return "scalar_add"; }
 };
@@ -392,9 +395,14 @@ struct SumNodeOp : public UnaryNodeOp {
   SumNodeOp(Expr a, Args... args)
       : UnaryNodeOp(a, keywords::shape = newShape(a, args...), args...) {}

-  NodeOps forwardOps() { return {NodeOp(Reduce(_1, val_, child(0)->val()))}; }
+  NodeOps forwardOps() {
+    using namespace functional;
+
+    return {NodeOp(Reduce(_1, val_, child(0)->val()))}; }

-  NodeOps backwardOps() { return {NodeOp(Add(_1, child(0)->grad(), adj_))}; }
+  NodeOps backwardOps() {
+    using namespace functional;
+    return {NodeOp(Add(_1, child(0)->grad(), adj_))}; }

   template <class... Args>
   Shape newShape(Expr a, Args... args) {
diff --git a/src/kernels/tensor_operators.cu b/src/kernels/tensor_operators.cu
index fc38042c..61abbc11 100644
--- a/src/kernels/tensor_operators.cu
+++ b/src/kernels/tensor_operators.cu
@@ -24,12 +24,13 @@ __device__ inline float stableLogit(float x) {
 }

 bool IsNan(Tensor in) {
-  cudaSetDevice(in->getDevice());
-  thrust::device_ptr<float> begin = thrust::device_pointer_cast(in->data());
-  thrust::device_ptr<float> end
-      = thrust::device_pointer_cast(in->data() + in->size());
-  return thrust::transform_reduce(
-      begin, end, isnan_test(), 0, thrust::plus<bool>());
+  //cudaSetDevice(in->getDevice());
+  //thrust::device_ptr<float> begin = thrust::device_pointer_cast(in->data());
+  //thrust::device_ptr<float> end
+  //    = thrust::device_pointer_cast(in->data() + in->size());
+  //return thrust::transform_reduce(
+  //    begin, end, isnan_test(), 0, thrust::plus<bool>());
+  return false;
 }

 void ConcatCont(Tensor out, const std::vector<Tensor>& inputs, int axis) {
@@ -1272,12 +1273,15 @@ void CrossEntropyPickBackward(Tensor out, Tensor adj, Tensor a, Tensor pick) {
 }

 float L2Norm(Tensor in) {
+  using namespace functional;
+
   cudaSetDevice(in->getDevice());

   uint8_t* data;
   cudaMalloc(&data, sizeof(float));
   Tensor out(new TensorBase(
       New<MemoryPiece>(data, sizeof(float)), {1, 1}, in->getDevice()));
+  ReduceAll(_1 * _1, out, in);
   float dataCpu = sqrtf(out->get(0));
   out.reset();
diff --git a/src/kernels/tensor_operators.h b/src/kernels/tensor_operators.h
index 54a14dd6..b41cc441 100644
--- a/src/kernels/tensor_operators.h
+++ b/src/kernels/tensor_operators.h
@@ -1,8 +1,7 @@
 #pragma once

 #include <cublas_v2.h>
-#include <thrust/device_vector.h>
-#include <thrust/host_vector.h>
+
 #include <thrust/pair.h>

 #include "tensors/tensor.h"
@@ -13,12 +12,12 @@
 #include "gpu/shape.h"
 #include "gpu/tmp.h"
 #include "gpu/tensor.h"
+#include "gpu/functions.h"

 namespace marian {

 bool IsNan(Tensor in);

-using namespace thrust::placeholders;

 const int MAX_THREADS = 512;
 const int MAX_BLOCKS = 65535;
diff --git a/src/layers/param_initializers.cu b/src/layers/param_initializers.cu
index 3b3cc2a4..f06c7c43 100644
--- a/src/layers/param_initializers.cu
+++ b/src/layers/param_initializers.cu
@@ -142,6 +142,8 @@ std::function<void(Tensor)> from_word2vec(const std::string& file,
                                           int dimVoc,
                                           int dimEmb,
                                           bool normalize /*= false*/) {
+  using namespace functional;
+
   return [file, dimVoc, dimEmb, normalize](Tensor t) {
     auto embs = Word2VecReader().read(file, dimVoc, dimEmb);
     t->set(embs);
diff --git a/src/training/graph_group_async.cu b/src/training/graph_group_async.cu
index 3f63eed2..706a85a5 100644
--- a/src/training/graph_group_async.cu
+++ b/src/training/graph_group_async.cu
@@ -185,7 +185,8 @@ void AsyncGraphGroup::execute(Ptr<data::Batch> batch) {
           accGradients->set(0);
         }

-        Element(_1 += _2, accGradients, graph->params()->grads());
+        using namespace functional;
+        Element(_1 = _1 + _2, accGradients, graph->params()->grads());
         gradients = accGradients;

         // Keep track of how many words we've calculated the error from
diff --git a/src/training/graph_group_singleton.cu b/src/training/graph_group_singleton.cu
index 1df8ea78..0b0958ae 100644
--- a/src/training/graph_group_singleton.cu
+++ b/src/training/graph_group_singleton.cu
@@ -13,6 +13,7 @@ void SingletonGraph::setScheduler(Ptr<Scheduler> scheduler) {
 void SingletonGraph::updateMovingAverage(Tensor mvAvgParams,
                                          Tensor params,
                                          size_t batches) {
+  using namespace functional;
   float decay = min(mvDecay_, (float)(batches + 1) / (float)(batches + 10));
   Element(_1 = (decay * _1) + ((1.f - decay) * _2), mvAvgParams, params);
 }
diff --git a/src/training/graph_group_sync.cu b/src/training/graph_group_sync.cu
index bca443a7..f38d2618 100644
--- a/src/training/graph_group_sync.cu
+++ b/src/training/graph_group_sync.cu
@@ -15,6 +15,7 @@ void SyncGraphGroup::setScheduler(Ptr<Scheduler> scheduler) {
 void SyncGraphGroup::updateMovingAverage(Tensor paramsAvg,
                                          Tensor params,
                                          size_t batches) {
+  using namespace functional;
   float decay = min(mvDecay_, (float)(batches + 1) / (float)(batches + 10));
   Element(_1 = (decay * _1) + ((1.f - decay) * _2), paramsAvg, params);
 }
@@ -130,7 +131,9 @@ void SyncGraphGroup::execute(Ptr<data::Batch> batch) {
           if(batches[i]->size() > 0) {
             auto subGrad = graph->params()->grads()->subtensor(pos, size);
             tmpTensors_[idx]->copyFrom(subGrad);
-            Element(_1 += _2, grads_[idx], tmpTensors_[idx]);
+
+            using namespace functional;
+            Element(_1 = _1 + _2, grads_[idx], tmpTensors_[idx]);
           }
           i++;
         }
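
Most of the hunks above do the same thing: drop the global `using namespace thrust::placeholders;` from `tensor_operators.h` and pull in Marian's own placeholders with a local `using namespace functional;`, backed by the new `gpu/functions.h` include. As a rough mental model only, not Marian's actual implementation, placeholders like `_1` and `_2` can be built as small expression-template types whose overloaded operators assemble a functor that `Element`, `Add`, and `Reduce` then apply per element. The CPU-only sketch below assumes exactly that; `Arg`, `Plus`, and the toy `Element` are all illustrative names:

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

namespace functional {

template <int N>
struct Arg {
  // Evaluate to the N-th call argument (1-based, like _1/_2 in the diff).
  float operator()(float a, float b) const { return N == 1 ? a : b; }
};

template <class L, class R>
struct Plus {
  L l;
  R r;
  float operator()(float a, float b) const { return l(a, b) + r(a, b); }
};

// Combining two expressions with + yields a new expression node.
template <class L, class R>
Plus<L, R> operator+(L l, R r) {
  return Plus<L, R>{l, r};
}

static const Arg<1> _1;
static const Arg<2> _2;

}  // namespace functional

// Toy CPU Element(): writes f(out[i], in[i]) back into out[i]. The real
// framework also overloads operator= so the assignment itself is part of
// the functor and the whole expression compiles into one CUDA kernel.
template <class Functor>
void Element(Functor f, std::vector<float>& out, const std::vector<float>& in) {
  for(std::size_t i = 0; i < out.size(); ++i)
    out[i] = f(out[i], in[i]);
}

int main() {
  using namespace functional;  // the same line the commit adds at each call site

  std::vector<float> acc = {1, 2, 3};
  std::vector<float> grad = {10, 20, 30};
  Element(_1 + _2, acc, grad);  // the effect of Element(_1 = _1 + _2, ...)
  assert(acc[0] == 11.f && acc[2] == 33.f);
  return 0;
}
```

This also suggests why the training code changes `_1 += _2` to `_1 = _1 + _2`: a hand-rolled placeholder type supports exactly the operators it overloads, and compound assignment apparently is not among them.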
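
One behavioral change is easy to miss: `IsNan` in `tensor_operators.cu` no longer inspects the tensor at all; the thrust-based `transform_reduce` is commented out and the function unconditionally returns `false`. If a working check were wanted without thrust, a plain kernel would do. The following is a hypothetical sketch, not code from this commit; only the CUDA runtime and device functions are real API, while `anyNanKernel` and `hasNan` are invented names:

```cpp
#include <algorithm>
#include <cuda_runtime.h>

__global__ void anyNanKernel(const float* data, int n, int* flag) {
  // Grid-stride loop so a bounded launch covers tensors of any size.
  for(int i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
      i += gridDim.x * blockDim.x)
    if(isnan(data[i]))
      atomicOr(flag, 1);  // any NaN sets the flag; ordering is irrelevant
}

bool hasNan(const float* deviceData, int n) {
  if(n <= 0)
    return false;

  int* flagDevice;
  cudaMalloc(&flagDevice, sizeof(int));
  cudaMemset(flagDevice, 0, sizeof(int));

  const int threads = 512;  // mirrors MAX_THREADS in tensor_operators.h
  const int blocks = std::min((n + threads - 1) / threads, 65535);  // MAX_BLOCKS
  anyNanKernel<<<blocks, threads>>>(deviceData, n, flagDevice);

  int flagHost = 0;
  cudaMemcpy(&flagHost, flagDevice, sizeof(int), cudaMemcpyDeviceToHost);
  cudaFree(flagDevice);
  return flagHost != 0;
}
```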
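
`L2Norm` now runs `ReduceAll(_1 * _1, out, in)`, leaving the sum of squares in a one-element device tensor, and then takes `sqrtf` of that value on the host. A CPU reference for sanity-checking that path might look like this (`l2NormRef` is an illustrative name):

```cpp
#include <cmath>
#include <vector>

float l2NormRef(const std::vector<float>& v) {
  float sumSq = 0.f;
  for(float x : v)
    sumSq += x * x;  // what ReduceAll(_1 * _1, ...) computes on the device
  return sqrtf(sumSq);
}
```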
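
Both `updateMovingAverage` implementations gain the same `using namespace functional;` line; the computation itself is an exponential moving average with a warm-up schedule: `decay = min(mvDecay_, (batches + 1) / (batches + 10))` climbs toward `mvDecay_` as training progresses, so early averages weight the fresh parameters more heavily. A CPU analogue of that `Element` call, with `movingAverageRef` as an illustrative name:

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

void movingAverageRef(std::vector<float>& avg,
                      const std::vector<float>& params,
                      std::size_t batches,
                      float mvDecay) {
  // Warm-up schedule: decay approaches mvDecay as the batch count grows.
  float decay = std::min(mvDecay, (float)(batches + 1) / (float)(batches + 10));
  for(std::size_t i = 0; i < avg.size(); ++i)
    avg[i] = decay * avg[i] + (1.f - decay) * params[i];
}
```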