diff options
author | Marcin Junczys-Dowmunt <junczys@amu.edu.pl> | 2018-02-16 22:59:12 +0300 |
---|---|---|
committer | Marcin Junczys-Dowmunt <junczys@amu.edu.pl> | 2018-02-16 22:59:12 +0300 |
commit | 327cfc1cc3fbe3ab92927c800aa30aef1c41d517 (patch) | |
tree | 9e67c0035ebb7d2c022982443efc1ad338ebc615 /src | |
parent | dd296e77f76143033fc1589c7dce6d12196bbfdd (diff) |
pass through backend
Diffstat (limited to 'src')
-rw-r--r-- | src/graph/expression_graph.cpp | 6 | ||||
-rw-r--r-- | src/graph/node_operators_unary.h | 24 | ||||
-rw-r--r-- | src/graph/parameters.h | 6 | ||||
-rw-r--r-- | src/kernels/sparse.cu | 6 | ||||
-rw-r--r-- | src/kernels/sparse.h | 33 | ||||
-rw-r--r-- | src/kernels/tensor_operators.cu | 6 | ||||
-rw-r--r-- | src/optimizers/optimizers.cu | 4 | ||||
-rw-r--r-- | src/tensors/cpu/dropout.cpp | 2 | ||||
-rw-r--r-- | src/tensors/dispatch.h | 25 | ||||
-rw-r--r-- | src/tensors/gpu/dropout.cu | 5 | ||||
-rw-r--r-- | src/tensors/tensor.cu | 16 | ||||
-rw-r--r-- | src/tensors/tensor.h | 12 | ||||
-rw-r--r-- | src/tensors/tensor_allocator.h | 10 | ||||
-rw-r--r-- | src/training/graph_group_async.cu | 14 | ||||
-rw-r--r-- | src/training/graph_group_async_drop.cu | 21 | ||||
-rw-r--r-- | src/training/graph_group_async_drop.h | 2 | ||||
-rw-r--r-- | src/training/graph_group_multinode.cu | 15 | ||||
-rw-r--r-- | src/training/graph_group_multinode.h | 2 | ||||
-rw-r--r-- | src/training/graph_group_sync.cu | 4 | ||||
-rw-r--r-- | src/training/sparse_tensor.cu | 24 | ||||
-rw-r--r-- | src/training/sparse_tensor.h | 8 |
21 files changed, 116 insertions, 129 deletions
diff --git a/src/graph/expression_graph.cpp b/src/graph/expression_graph.cpp index 183b5787..934e2b73 100644 --- a/src/graph/expression_graph.cpp +++ b/src/graph/expression_graph.cpp @@ -12,15 +12,15 @@ void ExpressionGraph::setDevice(DeviceId deviceId) { if(!backend_) { backend_ = BackendByDevice(deviceId, Config::seed); params_ = New<Parameters>(); - params_->init(backend_->getDevice()); - tensors_ = New<TensorAllocator>(backend_->getDevice()); + params_->init(backend_); + tensors_ = New<TensorAllocator>(backend_); } } Expr ExpressionGraph::dropout(float prob, Shape shape) { return Expression<ConstantNode>(shared_from_this(), keywords::init = [prob, this](Tensor t) { - Dropout(backend_, t, prob); + Dropout(t, prob); }, keywords::shape = shape); } diff --git a/src/graph/node_operators_unary.h b/src/graph/node_operators_unary.h index a3f27fd2..0170fc73 100644 --- a/src/graph/node_operators_unary.h +++ b/src/graph/node_operators_unary.h @@ -211,21 +211,7 @@ struct TanhNodeOp : public NaryNodeOp { const std::string type() { return "tanh"; } }; -/** - * Represents a <a - * href="https://en.wikipedia.org/wiki/Rectifier_(neural_networks)">rectified - * linear</a> node in an expression graph. - * - * This node implements the activation function \f$ f(x) = \max(0, x) \f$ and - * its derivative: - * \f[ - * f^\prime(x) = - * \begin{cases} - * 0 & \text{if } x \leq 0 \\ - * 1 & \text{if } x > 0 - * \end{cases} - * \f] - */ + struct ReLUNodeOp : public UnaryNodeOp { template <typename... Args> ReLUNodeOp(Args... args) : UnaryNodeOp(args...) {} @@ -877,14 +863,14 @@ public: Tensor& val() { auto childVal = reshapee_->val(); val_.reset( - new TensorBase(childVal->memory(), shape(), childVal->getDevice())); + new TensorBase(childVal->memory(), shape(), childVal->getBackend())); return val_; }; Tensor& grad() { auto childGrad = reshapee_->grad(); adj_.reset( - new TensorBase(childGrad->memory(), shape(), childGrad->getDevice())); + new TensorBase(childGrad->memory(), shape(), childGrad->getBackend())); return adj_; }; @@ -953,7 +939,7 @@ public: size_t offset = step_ * shape().elements() * sizeof(float); auto mem = New<MemoryPiece>(childVal->memory()->data() + offset, childVal->memory()->size()); - val_.reset(new TensorBase(mem, shape(), childVal->getDevice())); + val_.reset(new TensorBase(mem, shape(), childVal->getBackend())); return val_; }; @@ -962,7 +948,7 @@ public: size_t offset = step_ * shape().elements() * sizeof(float); auto mem = New<MemoryPiece>(childGrad->memory()->data() + offset, childGrad->memory()->size()); - adj_.reset(new TensorBase(mem, shape(), childGrad->getDevice())); + adj_.reset(new TensorBase(mem, shape(), childGrad->getBackend())); return adj_; }; diff --git a/src/graph/parameters.h b/src/graph/parameters.h index ed8b7690..3f282e4a 100644 --- a/src/graph/parameters.h +++ b/src/graph/parameters.h @@ -20,9 +20,9 @@ private: Ptr<TensorAllocator> grads_; public: - void init(DeviceId deviceId) { - vals_ = New<TensorAllocator>(deviceId); - grads_ = New<TensorAllocator>(deviceId); + void init(Ptr<Backend> backend) { + vals_ = New<TensorAllocator>(backend); + grads_ = New<TensorAllocator>(backend); } auto begin() -> decltype(params_.begin()) { return params_.begin(); } diff --git a/src/kernels/sparse.cu b/src/kernels/sparse.cu index 1d104474..b5080c0c 100644 --- a/src/kernels/sparse.cu +++ b/src/kernels/sparse.cu @@ -12,7 +12,7 @@ void multiply(Ptr<CSR> C, const Ptr<CSR> B, bool transA, bool transB) { - cudaSetDevice(C->getDevice()); + cudaSetDevice(backend_->getDevice().no); int nnzTotal; C->allocRowIndices(A->rows()); CUSPARSE_CHECK(cusparseXcsrgemmNnz( @@ -91,7 +91,7 @@ void multiply(Ptr<CSR> C, //} void LfaForward(Tensor out, Tensor logits, Tensor att, Ptr<CSR> sparseLf) { - cudaSetDevice(out->getDevice()); + cudaSetDevice(backend_->getDevice().no); int batch = att->shape()[0]; int srcWords = att->shape()[2]; @@ -150,7 +150,7 @@ __global__ void gCollapseAtt(float* out, } void CollapseAtt(Tensor out, Tensor in) { - cudaSetDevice(out->getDevice()); + cudaSetDevice(backend_->getDevice().no); int nonzeros = out->shape().elements(); int batch = out->shape()[0]; int srcWords = out->shape()[2]; diff --git a/src/kernels/sparse.h b/src/kernels/sparse.h index d70555f1..cffb398e 100644 --- a/src/kernels/sparse.h +++ b/src/kernels/sparse.h @@ -14,7 +14,7 @@ private: int nnz_{0}; int rows_{0}; int cols_{0}; - DeviceId deviceId_; + Ptr<Backend> backend_; cusparseHandle_t handle_{0}; cusparseMatDescr_t descr_{0}; @@ -24,9 +24,9 @@ private: float* values_{0}; public: - CSR(int rows, int cols, DeviceId deviceId) - : rows_(rows), cols_(cols), deviceId_(deviceId) { - cudaSetDevice(deviceId_.no); + CSR(int rows, int cols, Ptr<Backend> backend) + : rows_(rows), cols_(cols), backend_(backend) { + cudaSetDevice(backend_->getDevice().no); CUSPARSE_CHECK(cusparseCreate(&handle_)); CUSPARSE_CHECK(cusparseCreateMatDescr(&descr_)); CUSPARSE_CHECK(cusparseSetMatType(descr_, CUSPARSE_MATRIX_TYPE_GENERAL)); @@ -38,9 +38,9 @@ public: const std::vector<float>& values, const std::vector<int>& rowIndices, const std::vector<int>& colIndices, - DeviceId deviceId) - : nnz_(values.size()), rows_(rows), cols_(cols), deviceId_(deviceId) { - cudaSetDevice(deviceId_.no); + Ptr<Backend> backend) + : nnz_(values.size()), rows_(rows), cols_(cols), backend_(backend) { + cudaSetDevice(backend_->getDevice().no); CUSPARSE_CHECK(cusparseCreate(&handle_)); CUSPARSE_CHECK(cusparseCreateMatDescr(&descr_)); CUSPARSE_CHECK(cusparseSetMatType(descr_, CUSPARSE_MATRIX_TYPE_GENERAL)); @@ -73,8 +73,8 @@ public: CUDA_CHECK(cudaFree(cooRowIndices)); } - CSR(Tensor dense) : deviceId_(dense->getDevice()) { - cudaSetDevice(deviceId_.no); + CSR(Tensor dense) : backend_(dense->getBackend()) { + cudaSetDevice(backend_->getDevice().no); rows_ = dense->shape()[0] * dense->shape()[2] * dense->shape()[3]; cols_ = dense->shape()[1]; @@ -114,7 +114,7 @@ public: } ~CSR() { - cudaSetDevice(deviceId_.no); + cudaSetDevice(backend_->getDevice().no); if(values_) CUDA_CHECK(cudaFree(values_)); if(rowIndices_) @@ -129,7 +129,7 @@ public: } void toTensor(Tensor dense) { - cudaSetDevice(deviceId_.no); + cudaSetDevice(backend_->getDevice().no); ABORT_IF(dense->size() != rows_ * cols_, "Matrix sizes do not match"); cusparseScsc2dense(handle_, @@ -154,10 +154,10 @@ public: int* rowIndices() { return rowIndices_; } int* colIndices() { return colIndices_; } - DeviceId getDevice() { return deviceId_; } + DeviceId getDevice() { return backend_->getDevice(); } void allocValues(int nnz = 0) { - cudaSetDevice(deviceId_.no); + cudaSetDevice(backend_->getDevice().no); if(nnz > 0) nnz_ = nnz; if(values_) @@ -166,7 +166,7 @@ public: } void allocRowIndices(int rows = 0) { - cudaSetDevice(deviceId_.no); + cudaSetDevice(backend_->getDevice().no); if(rows > 0) rows_ = rows; if(rowIndices_) @@ -175,7 +175,7 @@ public: } void allocColIndices(int nnz = 0) { - cudaSetDevice(deviceId_.no); + cudaSetDevice(backend_->getDevice().no); if(nnz > 0) nnz_ = nnz; if(colIndices_) @@ -184,11 +184,12 @@ public: } std::string debug() { + cudaSetDevice(backend_->getDevice().no); uint8_t* buffer; CUDA_CHECK(cudaMalloc(&buffer, sizeof(float) * rows() * cols())); auto mem = New<MemoryPiece>(buffer, sizeof(float) * rows() * cols()); - Tensor tensor(new TensorBase(mem, {rows(), cols()}, deviceId_)); + Tensor tensor(new TensorBase(mem, {rows(), cols()}, backend_)); toTensor(tensor); std::string temp = tensor->debug(); diff --git a/src/kernels/tensor_operators.cu b/src/kernels/tensor_operators.cu index 69b8afc4..87019d9e 100644 --- a/src/kernels/tensor_operators.cu +++ b/src/kernels/tensor_operators.cu @@ -893,7 +893,7 @@ void Select(Ptr<Allocator> allocator, auto mp_indices = allocator->alloc<size_t>(indices.size()); CudaCopy(indices.data(), indices.data() + indices.size(), mp_indices->data()); - + int axisGPU = axis + gpu::Shape::size() - out->shape().size(); gSelect<<<blocks, threads>>>(out->data(), out->shape(), @@ -919,7 +919,7 @@ void Insert(Ptr<Allocator> allocator, auto mp_indices = allocator->alloc<size_t>(indices.size()); CudaCopy(indices.data(), indices.data() + indices.size(), mp_indices->data()); - + int axisGPU = axis + gpu::Shape::size() - out->shape().size(); gInsert<<<blocks, threads>>>(out->data(), out->shape(), @@ -1295,7 +1295,7 @@ float L2Norm(Tensor in) { uint8_t* data; cudaMalloc(&data, blocks * sizeof(float)); Tensor out(new TensorBase( - New<MemoryPiece>(data, blocks * sizeof(float)), {1, blocks}, in->getDevice())); + New<MemoryPiece>(data, blocks * sizeof(float)), {1, blocks}, in->getBackend())); ReduceAll(_1 * _1, out, in); float dataCpu = sqrtf(out->get(0)); diff --git a/src/optimizers/optimizers.cu b/src/optimizers/optimizers.cu index 2874e1d5..e82800c9 100644 --- a/src/optimizers/optimizers.cu +++ b/src/optimizers/optimizers.cu @@ -13,7 +13,7 @@ void Sgd::updateImpl(Tensor params, Tensor grads) { void Adagrad::updateImpl(Tensor params, Tensor grads) { if(!alloc_) - alloc_ = New<TensorAllocator>(params->getDevice()); + alloc_ = New<TensorAllocator>(params->getBackend()); if(!gt_) { int elements = params->size(); @@ -42,7 +42,7 @@ void Adagrad::resetStats() { void Adam::updateImpl(Tensor params, Tensor grads) { if(!alloc_) - alloc_ = New<TensorAllocator>(params->getDevice()); + alloc_ = New<TensorAllocator>(params->getBackend()); if(!mt_) { int elements = params->size(); diff --git a/src/tensors/cpu/dropout.cpp b/src/tensors/cpu/dropout.cpp index 4286042b..cc6cea41 100644 --- a/src/tensors/cpu/dropout.cpp +++ b/src/tensors/cpu/dropout.cpp @@ -5,7 +5,7 @@ namespace marian { namespace cpu { - void Dropout(Ptr<marian::Backend> backend, Tensor tensor, float p) { + void Dropout(Tensor tensor, float p) { ABORT("Not implemented"); std::fill(tensor->data(), tensor->data() + tensor->size(), p); } diff --git a/src/tensors/dispatch.h b/src/tensors/dispatch.h index 14b74f5a..e63a6af1 100644 --- a/src/tensors/dispatch.h +++ b/src/tensors/dispatch.h @@ -1,35 +1,34 @@ #pragma once
#include "common/definitions.h"
-#include "tensors/backend.h"
#include "tensors/tensor.h"
#define DISPATCH1(Function, Arg1) \
namespace gpu { \
- void Function(Ptr<marian::Backend>, Arg1); \
+ void Function(Arg1); \
} \
namespace cpu { \
- void Function(Ptr<marian::Backend>, Arg1); \
+ void Function(Arg1); \
} \
- void Function(Ptr<marian::Backend> backend, Arg1 arg1) { \
- if(backend->getDevice().type == DeviceType::gpu) \
- gpu::Function(backend, arg1); \
+ void Function(Arg1 arg1) { \
+ if(arg1->getBackend()->getDevice().type == DeviceType::gpu) \
+ gpu::Function(arg1); \
else \
- cpu::Function(backend, arg1); \
+ cpu::Function(arg1); \
}
#define DISPATCH2(Function, Arg1, Arg2) \
namespace gpu { \
- void Function(Ptr<marian::Backend>, Arg1, Arg2); \
+ void Function(Arg1, Arg2); \
} \
namespace cpu { \
- void Function(Ptr<marian::Backend>, Arg1, Arg2); \
+ void Function(Arg1, Arg2); \
} \
- static inline void Function(Ptr<marian::Backend> backend, Arg1 arg1, Arg2 arg2) { \
- if(backend->getDevice().type == DeviceType::gpu) \
- gpu::Function(backend, arg1, arg2); \
+ static inline void Function(Arg1 arg1, Arg2 arg2) { \
+ if(arg1->getBackend()->getDevice().type == DeviceType::gpu) \
+ gpu::Function(arg1, arg2); \
else \
- cpu::Function(backend, arg1, arg2); \
+ cpu::Function(arg1, arg2); \
}
namespace marian {
diff --git a/src/tensors/gpu/dropout.cu b/src/tensors/gpu/dropout.cu index 6dc49c51..4a4223a8 100644 --- a/src/tensors/gpu/dropout.cu +++ b/src/tensors/gpu/dropout.cu @@ -35,8 +35,9 @@ namespace marian { } } - void Dropout(Ptr<marian::Backend> backend, Tensor tensor, float p) { - curandGenerator_t gen = std::static_pointer_cast<gpu::Backend>(backend)->getCurandGenerator(); + void Dropout(Tensor tensor, float p) { + auto gpuBackend = std::static_pointer_cast<gpu::Backend>(tensor->getBackend()); + curandGenerator_t gen = gpuBackend->getCurandGenerator(); int n = tensor->size(); CURAND_CALL(curandGenerateUniform(gen, tensor->data(), n)); diff --git a/src/tensors/tensor.cu b/src/tensors/tensor.cu index bc26fcec..96d979bf 100644 --- a/src/tensors/tensor.cu +++ b/src/tensors/tensor.cu @@ -16,7 +16,7 @@ __global__ void gFill(float *d_in, int size, float val) { } float TensorBase::get(size_t i) { - cudaSetDevice(deviceId_.no); + CUDA_CHECK(cudaSetDevice(backend_->getDevice().no)); float temp; CUDA_CHECK( cudaMemcpy(&temp, data() + i, sizeof(float), cudaMemcpyDeviceToHost)); @@ -25,14 +25,14 @@ float TensorBase::get(size_t i) { } void TensorBase::set(size_t i, float value) { - cudaSetDevice(deviceId_.no); + CUDA_CHECK(cudaSetDevice(backend_->getDevice().no)); CUDA_CHECK( cudaMemcpy(data() + i, &value, sizeof(float), cudaMemcpyHostToDevice)); cudaStreamSynchronize(0); } void TensorBase::get(std::vector<float> &v) { - CUDA_CHECK(cudaSetDevice(deviceId_.no)); + CUDA_CHECK(cudaSetDevice(backend_->getDevice().no)); v.resize(size()); CUDA_CHECK(cudaMemcpy( v.data(), data(), size() * sizeof(float), cudaMemcpyDeviceToHost)); @@ -40,7 +40,7 @@ void TensorBase::get(std::vector<float> &v) { } void TensorBase::set(float value) { - cudaSetDevice(deviceId_.no); + CUDA_CHECK(cudaSetDevice(backend_->getDevice().no)); int threads = std::min(512, (int)size()); int blocks = (size() / threads) + (size() % threads != 0); gFill<<<blocks, threads>>>(data(), size(), value); @@ -48,7 +48,7 @@ void TensorBase::set(float value) { } void TensorBase::set(const std::vector<float> &v) { - CUDA_CHECK(cudaSetDevice(deviceId_.no)); + CUDA_CHECK(cudaSetDevice(backend_->getDevice().no)); CUDA_CHECK(cudaMemcpy( data(), v.data(), v.size() * sizeof(float), cudaMemcpyHostToDevice)); cudaStreamSynchronize(0); @@ -56,13 +56,13 @@ void TensorBase::set(const std::vector<float> &v) { void TensorBase::setSparse(const std::vector<size_t> &k, const std::vector<float> &v) { - cudaSetDevice(deviceId_.no); + CUDA_CHECK(cudaSetDevice(backend_->getDevice().no)); SetSparse(data(), k, v); cudaStreamSynchronize(0); } void TensorBase::copyFrom(Tensor in) { - cudaSetDevice(deviceId_.no); + CUDA_CHECK(cudaSetDevice(backend_->getDevice().no)); CUDA_CHECK(cudaMemcpy(data(), (float *)in->data(), in->size() * sizeof(float), @@ -74,7 +74,7 @@ std::string TensorBase::debug() { std::stringstream strm; assert(shape_.size()); strm << shape_; - strm << " device=" << deviceId_; + strm << " device=" << backend_->getDevice(); strm << " ptr=" << (size_t)memory_->data(); strm << " bytes=" << memory_->size(); strm << std::endl; diff --git a/src/tensors/tensor.h b/src/tensors/tensor.h index 15d3a427..309cacd5 100644 --- a/src/tensors/tensor.h +++ b/src/tensors/tensor.h @@ -9,6 +9,7 @@ #include "common/definitions.h" #include "common/shape.h" #include "tensors/memory_piece.h" +#include "tensors/backend.h" namespace marian { @@ -16,11 +17,11 @@ class TensorBase : public std::enable_shared_from_this<TensorBase> { private: Ptr<MemoryPiece> memory_; Shape shape_; - DeviceId deviceId_; + Ptr<Backend> backend_; public: - TensorBase(Ptr<MemoryPiece> memory, Shape shape, DeviceId deviceId) - : memory_(memory), shape_(shape), deviceId_(deviceId) {} + TensorBase(Ptr<MemoryPiece> memory, Shape shape, Ptr<Backend> backend) + : memory_(memory), shape_(shape), backend_(backend) {} ~TensorBase() {} @@ -39,12 +40,13 @@ public: return get(0); } - DeviceId getDevice() { return deviceId_; } + Ptr<Backend> getBackend() { return backend_; } + DeviceId getDevice() { return backend_->getDevice(); } Tensor subtensor(int offset, int size) { auto mem = New<MemoryPiece>(memory_->data() + sizeof(float) * offset, sizeof(float) * size); - return Tensor(new TensorBase(mem, {1, size}, deviceId_)); + return New<TensorBase>(mem, Shape{1, size}, backend_); } float get(size_t i); diff --git a/src/tensors/tensor_allocator.h b/src/tensors/tensor_allocator.h index e1c54b22..18aae134 100644 --- a/src/tensors/tensor_allocator.h +++ b/src/tensors/tensor_allocator.h @@ -16,11 +16,13 @@ private: const size_t GROW = CHUNK * MBYTE; const size_t ALIGN = 256; + Ptr<Backend> backend_; Ptr<Allocator> allocator_; public: - TensorAllocator(DeviceId deviceId) - : allocator_(New<Allocator>(deviceId, 0, GROW, ALIGN)) {} + TensorAllocator(Ptr<Backend> backend) + : backend_(backend), + allocator_(New<Allocator>(backend_->getDevice(), 0, GROW, ALIGN)) {} ~TensorAllocator() { clear(); } @@ -58,7 +60,7 @@ public: if(!t || t->shape() != shape) { int size = shape.elements(); auto mem = allocator_->alloc<float>(size); - t = Tensor(new TensorBase(mem, shape, allocator_->getDevice())); + t = Tensor(new TensorBase(mem, shape, backend_)); } } @@ -67,7 +69,7 @@ public: Tensor asTensor() { auto mem = allocator_->memory(); int size = mem->size() / sizeof(float); - return Tensor(new TensorBase(mem, {1, size}, allocator_->getDevice())); + return Tensor(new TensorBase(mem, {1, size}, backend_)); } size_t size() { return allocator_->size() / sizeof(float); } diff --git a/src/training/graph_group_async.cu b/src/training/graph_group_async.cu index 90d575bd..18f3908a 100644 --- a/src/training/graph_group_async.cu +++ b/src/training/graph_group_async.cu @@ -93,12 +93,12 @@ void AsyncGraphGroup::init(Ptr<data::Batch> batch) { int pos = 0; // parameter sharding - for(auto device : devices_) { + for(auto graph : graphs_) { int __size__ = min(shardSize_, totalSize); totalSize -= __size__; Tensor param; - Ptr<TensorAllocator> allocator = New<TensorAllocator>(DeviceId{device, DeviceType::gpu}); + Ptr<TensorAllocator> allocator = New<TensorAllocator>(graph->getBackend()); allocator->reserveExact(__size__ * sizeof(float)); allocator->allocate(param, {1, __size__}); paramsAlloc_.push_back(allocator); @@ -112,11 +112,11 @@ void AsyncGraphGroup::init(Ptr<data::Batch> batch) { if(grads_.size() == 0) { int totalSize = graphs_[0]->params()->vals()->size(); - for(auto device : devices_) { + for(auto graph : graphs_) { int __size__ = min(shardSize_, totalSize); totalSize -= __size__; Tensor grad_; - Ptr<TensorAllocator> allocator_ = New<TensorAllocator>(DeviceId{device, DeviceType::gpu}); + Ptr<TensorAllocator> allocator_ = New<TensorAllocator>(graph->getBackend()); allocator_->reserveExact(__size__ * sizeof(float)); allocator_->allocate(grad_, {1, __size__}); @@ -129,11 +129,11 @@ void AsyncGraphGroup::init(Ptr<data::Batch> batch) { int totalSize = graphs_[0]->params()->vals()->size(); int i = 0; - for(auto device : devices_) { + for(auto graph : graphs_) { int __size__ = min(shardSize_, totalSize); totalSize -= __size__; Tensor paramAvg; - Ptr<TensorAllocator> allocator = New<TensorAllocator>(DeviceId{device, DeviceType::gpu}); + Ptr<TensorAllocator> allocator = New<TensorAllocator>(graph->getBackend()); allocator->reserveExact(__size__ * sizeof(float)); allocator->allocate(paramAvg, {1, __size__}); @@ -187,7 +187,7 @@ void AsyncGraphGroup::execute(Ptr<data::Batch> batch) { Tensor gradients; if(tau_ > 1) { if(t == 0) { - accAlloc = New<TensorAllocator>(graph->getDevice()); + accAlloc = New<TensorAllocator>(graph->getBackend()); accAlloc->reserveExact(graph->params()->grads()->memory()->size()); accAlloc->allocate(accGradients, graph->params()->grads()->shape()); accGradients->set(0); diff --git a/src/training/graph_group_async_drop.cu b/src/training/graph_group_async_drop.cu index 84985009..33684bf9 100644 --- a/src/training/graph_group_async_drop.cu +++ b/src/training/graph_group_async_drop.cu @@ -8,9 +8,9 @@ namespace marian { -Tensor AsyncGraphGroupDrop::newTensor(int size, DeviceId deviceId) { +Tensor AsyncGraphGroupDrop::newTensor(int size, Ptr<Backend> backend) { Tensor t; - Ptr<TensorAllocator> allocator_ = New<TensorAllocator>(deviceId); + Ptr<TensorAllocator> allocator_ = New<TensorAllocator>(backend); allocator_->reserveExact(size * sizeof(float)); allocator_->allocate(t, {1, size}); allocators.push_back(allocator_); @@ -86,7 +86,7 @@ void AsyncGraphGroupDrop::pushGradients(Tensor newGrads, // get the sparse gradient pushDropper_[device_id]->dropGraph( - newGrads, pushSparseGradient_[device_id], + newGrads, pushSparseGradient_[device_id], droping_rate, dropping_momentum); SparseTensor newSparseGrads = pushSparseGradient_[device_id]; @@ -146,13 +146,12 @@ void AsyncGraphGroupDrop::init(Ptr<data::Batch> batch) { fetchStep_.push_back(0); pushStep_.push_back(0); - size_t device = devices_[i]; // temporary tensor to compute parameter delta before fetching - paramsDelta_.push_back(newTensor(shardSize, {device, DeviceType::gpu})); + paramsDelta_.push_back(newTensor(shardSize, graphs_[i]->getBackend())); // tensors to store local params history for(int h_id = 0; h_id < devices_.size(); h_id++) { - Tensor tmp = newTensor(params_[i]->size(), {device, DeviceType::gpu}); + Tensor tmp = newTensor(params_[i]->size(), graphs_[i]->getBackend()); tmp->copyFrom(params_[i]); paramsLocal_[h_id].push_back(tmp); } @@ -162,23 +161,23 @@ void AsyncGraphGroupDrop::init(Ptr<data::Batch> batch) { // N-dropper for fetch std::vector<GradientDrop> tmpDropper; - for(int i = 0; i < devices_.size(); i++) + for(auto device : devices_) tmpDropper.push_back(GradientDrop(new GradientDropBase())); fetchDropper.push_back(tmpDropper); // sparsetensor to store sparsified gradients per-device pushSparseGradient_.push_back( - SparseTensor(new SparseTensorBase(sparseCap, {device, DeviceType::gpu}))); + SparseTensor(new SparseTensorBase(sparseCap, graphs_[i]->getBackend()))); pushShardedSparseGradient_.push_back( - SparseTensor(new SparseTensorBase(sparseCap, {device, DeviceType::gpu}))); + SparseTensor(new SparseTensorBase(sparseCap, graphs_[i]->getBackend()))); fetchSparseGradient_.push_back(SparseTensor( - new SparseTensorBase(sparseCap / devices_.size(), {device, DeviceType::gpu}))); + new SparseTensorBase(sparseCap / devices_.size(), graphs_[i]->getBackend()))); std::vector<SparseTensor> tmp; for(int i = 0; i < devices_.size(); i++) tmp.push_back(SparseTensor( - new SparseTensorBase(sparseCap / devices_.size(), {device, DeviceType::gpu}))); + new SparseTensorBase(sparseCap / devices_.size(), graphs_[i]->getBackend()))); fetchShardedSparseGradient_.push_back(tmp); } diff --git a/src/training/graph_group_async_drop.h b/src/training/graph_group_async_drop.h index 33e74a40..f32d9444 100644 --- a/src/training/graph_group_async_drop.h +++ b/src/training/graph_group_async_drop.h @@ -31,7 +31,7 @@ class AsyncGraphGroupDrop : public AsyncGraphGroup { std::vector<Ptr<TensorAllocator>> allocators; - Tensor newTensor(int size, DeviceId deviceId); + Tensor newTensor(int size, Ptr<Backend> backend); protected: void init(Ptr<data::Batch> batch); diff --git a/src/training/graph_group_multinode.cu b/src/training/graph_group_multinode.cu index 78cd842f..34aa2b5d 100644 --- a/src/training/graph_group_multinode.cu +++ b/src/training/graph_group_multinode.cu @@ -19,9 +19,9 @@ void MultiNodeGraphGroup::setScheduler(Ptr<Scheduler> scheduler) { /** * Allocate new tensor on given GPU and store allocator. */ -Tensor MultiNodeGraphGroup::newTensor(int size, DeviceId deviceId) { +Tensor MultiNodeGraphGroup::newTensor(int size, Ptr<Backend> backend) { Tensor t; - Ptr<TensorAllocator> allocator = New<TensorAllocator>(deviceId); + Ptr<TensorAllocator> allocator = New<TensorAllocator>(backend); allocator->reserveExact(size * sizeof(float)); allocator->allocate(t, {1, size}); allocators_.push_back(allocator); @@ -148,14 +148,12 @@ void MultiNodeGraphGroup::initClientCommOverlapVars() { void MultiNodeGraphGroup::initClientCommOverlapGpuTensors() { size_t modelSize = clientGraphs_[0]->params()->vals()->size(); for(int client = 0; client < devices_.size(); client++) { - DeviceId deviceId{devices_[client], DeviceType::gpu}; - // Communication overlap buffer (for grads + params) - Tensor commOverlapBuffer = newTensor(modelSize, deviceId); + Tensor commOverlapBuffer = newTensor(modelSize, clientGraphs_[client]->getBackend()); commOverlapBuffer->copyFrom(clientGraphs_[0]->params()->vals()); clientCommOverlapBuffersGPU_.push_back(commOverlapBuffer); // Gradients local sum buffer - Tensor sumGrads = newTensor(modelSize, deviceId); + Tensor sumGrads = newTensor(modelSize, clientGraphs_[client]->getBackend()); sumGrads->set(0); clientSummedGradsGPU.push_back(sumGrads); // Local optimizer to apply summed gradients @@ -207,12 +205,11 @@ void MultiNodeGraphGroup::calculateShardSizes() { void MultiNodeGraphGroup::initShardGpuTensors() { size_t offset = 0; for(int shard = 0; shard < devices_.size(); shard++) { - DeviceId deviceId{devices_[shard], DeviceType::gpu}; - Tensor gpuParams = newTensor(shardSizes_[shard], deviceId); + Tensor gpuParams = newTensor(shardSizes_[shard], clientGraphs_[shard]->getBackend()); gpuParams->copyFrom(clientGraphs_[0]->params()->vals()->subtensor( offset, shardSizes_[shard])); shardParams_.push_back(gpuParams); - shardGrads_.push_back(newTensor(shardSizes_[shard], deviceId)); + shardGrads_.push_back(newTensor(shardSizes_[shard], clientGraphs_[shard]->getBackend())); } } diff --git a/src/training/graph_group_multinode.h b/src/training/graph_group_multinode.h index 90243f7c..c6dc495c 100644 --- a/src/training/graph_group_multinode.h +++ b/src/training/graph_group_multinode.h @@ -217,7 +217,7 @@ protected: /** * Allocate new tensor on given GPU and store allocator. */ - Tensor newTensor(int size, DeviceId deviceId); + Tensor newTensor(int size, Ptr<Backend> backend); /** * Setup training environment and launch server thread and (if enabled) client diff --git a/src/training/graph_group_sync.cu b/src/training/graph_group_sync.cu index e713b99d..171c0652 100644 --- a/src/training/graph_group_sync.cu +++ b/src/training/graph_group_sync.cu @@ -60,7 +60,7 @@ void SyncGraphGroup::execute(Ptr<data::Batch> batch) { for(auto graph : graphs_) { int __size__ = min(shardSize_, totalSize); - auto paramsAlloc = New<TensorAllocator>(graph->getDevice()); + auto paramsAlloc = New<TensorAllocator>(graph->getBackend()); paramsAllocs_.push_back(paramsAlloc); paramsAlloc->reserveExact(3 * __size__ * sizeof(float)); @@ -87,7 +87,7 @@ void SyncGraphGroup::execute(Ptr<data::Batch> batch) { int __size__ = min(shardSize_, totalSize); totalSize -= __size__; Tensor paramAvg; - auto allocator = New<TensorAllocator>(graph->getDevice()); + auto allocator = New<TensorAllocator>(graph->getBackend()); allocator->reserveExact(__size__ * sizeof(float)); allocator->allocate(paramAvg, {1, __size__}); diff --git a/src/training/sparse_tensor.cu b/src/training/sparse_tensor.cu index 5d32b16f..aafafa97 100644 --- a/src/training/sparse_tensor.cu +++ b/src/training/sparse_tensor.cu @@ -45,9 +45,9 @@ __global__ void gFindSubtensor(int* indices, resultEnd[0] = idx; } -SparseTensorBase::SparseTensorBase(int capacity, DeviceId deviceId) -: deviceId_(deviceId), capacity_(capacity) { - cudaSetDevice(deviceId_.no); +SparseTensorBase::SparseTensorBase(int capacity, Ptr<Backend> backend) +: backend_(backend), capacity_(capacity) { + cudaSetDevice(backend_->getDevice().no); CUDA_CHECK(cudaMalloc(&data_, sizeof(float) * capacity)); CUDA_CHECK(cudaMalloc(&indices_, sizeof(int) * capacity)); @@ -58,8 +58,8 @@ SparseTensorBase::SparseTensorBase(int capacity, DeviceId deviceId) SparseTensorBase::SparseTensorBase(float* data, int* indices, int size, - DeviceId deviceId) -: deviceId_(deviceId) { + Ptr<Backend> backend) +: backend_(backend) { data_ = data; indices_ = indices; size_ = size; @@ -93,7 +93,7 @@ void SparseTensorBase::copyFrom(float* data, size_ = size; if(size == 0) return; - cudaSetDevice(deviceId_.no); + cudaSetDevice(backend_->getDevice().no); cudaMemcpy(data_, data, size * sizeof(float), cudaMemcpyDefault); if(!data_only) @@ -107,8 +107,8 @@ void SparseTensorBase::copyFrom(std::shared_ptr<SparseTensorBase> t, copyFrom(t->data(), t->indices(), t->size(), data_only); } -DeviceId SparseTensorBase::getDevice() { - return deviceId_; +Ptr<Backend> SparseTensorBase::getBackend() { + return backend_; } void SparseTensorBase::setSize(int size) { @@ -117,7 +117,7 @@ void SparseTensorBase::setSize(int size) { // return the dense representation of this tensor void SparseTensorBase::toDense(Tensor t, int offset) { - cudaSetDevice(deviceId_.no); + cudaSetDevice(backend_->getDevice().no); int threads = 512; int blocks = 1 + size_ / threads; t->set(0); @@ -127,7 +127,7 @@ void SparseTensorBase::toDense(Tensor t, int offset) { } void SparseTensorBase::scatterAdd(Tensor t, int offset) { - cudaSetDevice(deviceId_.no); + cudaSetDevice(backend_->getDevice().no); cudaStreamSynchronize(0); int threads = 512; int blocks = 1 + size_ / threads; @@ -139,7 +139,7 @@ void SparseTensorBase::scatterAdd(Tensor t, int offset) { std::shared_ptr<SparseTensorBase> SparseTensorBase::subtensor(int pos, int size, int idx) { - cudaSetDevice(deviceId_.no); + cudaSetDevice(backend_->getDevice().no); cudaStreamSynchronize(0); int* start = gstart_ + idx; int* end = gend_ + idx; @@ -165,6 +165,6 @@ std::shared_ptr<SparseTensorBase> SparseTensorBase::subtensor(int pos, int subtensorSize = std::max(0, endOffset - startOffset + 1); cudaStreamSynchronize(0); return std::shared_ptr<SparseTensorBase>(new SparseTensorBase( - data_ + startOffset, indices_ + startOffset, subtensorSize, deviceId_)); + data_ + startOffset, indices_ + startOffset, subtensorSize, backend_)); } } diff --git a/src/training/sparse_tensor.h b/src/training/sparse_tensor.h index 03fb53a9..9194748f 100644 --- a/src/training/sparse_tensor.h +++ b/src/training/sparse_tensor.h @@ -10,15 +10,15 @@ class SparseTensorBase : public std::enable_shared_from_this<SparseTensorBase> { int* indices_; int size_; int capacity_; - DeviceId deviceId_; + Ptr<Backend> backend_; int* d_is_unsorted; int* gstart_; int* gend_; public: - SparseTensorBase(int capacity, DeviceId deviceId); - SparseTensorBase(float* data, int* indices, int size, DeviceId deviceId); + SparseTensorBase(int capacity, Ptr<Backend> backend); + SparseTensorBase(float* data, int* indices, int size, Ptr<Backend> backend); ~SparseTensorBase() {} @@ -43,7 +43,7 @@ public: void scatterAdd(Tensor t, int offset = 0); std::shared_ptr<SparseTensorBase> subtensor(int pos, int size, int idx); - DeviceId getDevice(); + Ptr<Backend> getBackend(); void toDense(Tensor t, int offset); }; |